{
  "metadata": {
    "custom": {
      "cells": [],
      "metadata": {
        "accelerator": "GPU",
        "colab": {
          "background_execution": "on",
          "collapsed_sections": [],
          "machine_shape": "hm",
          "name": "Torchrec Introduction.ipynb",
          "provenance": []
        },
        "fileHeader": "",
        "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00",
        "interpreter": {
          "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
        },
        "isAdHoc": false,
        "kernelspec": {
          "display_name": "Python 3.7.13 ('torchrec': conda)",
          "language": "python",
          "name": "python3"
        },
        "language_info": {
          "codemirror_mode": {
            "name": "ipython",
            "version": 3
          },
          "file_extension": ".py",
          "mimetype": "text/x-python",
          "name": "python",
          "nbconvert_exporter": "python",
          "pygments_lexer": "ipython3",
          "version": "3.7.13"
        }
      },
      "nbformat": 4,
      "nbformat_minor": 0
    },
    "indentAmount": 2,
    "last_server_session_id": "e11f329f-b395-4702-9b33-449716ea422e",
    "last_kernel_id": "b6fe1a08-1d4d-40cd-afe6-8352c4e42d25",
    "last_base_url": "https://bento.edge.x2p.facebook.net/",
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "last_msg_id": "c02547e3-e4c072dc430f066c4d18479a_594",
    "captumWidgetMessage": [],
    "outputWidgetContext": [],
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "accelerator": "GPU",
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0,
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBgIy9eYYx35",
        "originalKey": "4766a371-bf6e-4342-98fb-16dde5255d73",
        "outputsInitialized": false,
        "language": "markdown",
        "showInput": false
      },
      "source": [
        "## **Open Source Installation** (For Reference)\n",
        "Requirements:\n",
        "- python >= 3.9\n",
        "\n",
        "We highly recommend CUDA when using TorchRec. If using CUDA:\n",
        "- cuda >= 11.8\n",
        "\n",
        "Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU-enabled operators that TorchRec uses to run embedding operations efficiently."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sFYvP95xaAER",
        "originalKey": "27d22c43-9299-46ec-94f2-28a880546fe3",
        "outputsInitialized": true,
        "language": "python",
        "customOutput": null,
        "executionStartTime": 1726000131275,
        "executionStopTime": 1726000131459,
        "serverExecutionDuration": 2.2683702409267,
        "requestMsgId": "27d22c43-9299-46ec-94f2-28a880546fe3"
      },
      "source": [
        "# Install stable versions for best reliability\n",
        "\n",
        "!pip3 install torch --index-url https://download.pytorch.org/whl/cu121 -U\n",
        "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121\n",
        "!pip3 install torchmetrics==1.0.3\n",
        "!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121"
      ],
      "execution_count": 1,
      "outputs": []
    },
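    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optionally, you can sanity-check the environment after installing. This is a small sketch that uses only standard PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`) and the standard-library `importlib.util.find_spec`. Whether CUDA is reported as available depends on your runtime, for example whether a GPU instance is attached in Colab."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import importlib.util\n",
        "\n",
        "import torch\n",
        "\n",
        "# Report the installed PyTorch build and the CUDA version it was compiled against\n",
        "print(\"torch version:\", torch.__version__)\n",
        "print(\"built with CUDA:\", torch.version.cuda)\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "\n",
        "# Check that the TorchRec and FBGEMM packages can be found without importing them\n",
        "for pkg in [\"torchrec\", \"fbgemm_gpu\"]:\n",
        "    print(pkg, \"installed:\", importlib.util.find_spec(pkg) is not None)"
      ],
      "execution_count": null,
      "outputs": []
    },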
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-4DFtQNDYao1",
        "originalKey": "07e2a5ae-9ca2-45d7-af10-84d8e09ce91e",
        "outputsInitialized": false,
        "language": "markdown",
        "showInput": false
      },
      "source": [
        "# Intro to TorchRec\n",
        "\n",
        "### Embeddings\n",
        "When building recommendation systems, categorical features such as posts, users, and ads typically have massive cardinalities.\n",
        "\n",
        "To represent these entities and model their relationships, **embeddings** are used. In machine learning, **embeddings are vectors of real numbers in a high-dimensional space used to represent meaning in complex data like words, images, or users**.\n",
        "\n",
        "\n",
        "### Embeddings in RecSys\n",
        "\n",
        "Now you might wonder: how are these embeddings generated in the first place? Embeddings are stored as individual rows in an **Embedding Table**, also referred to as the embedding weights. These weights are trained just like all the other weights of the model, via gradient descent!\n",
        "\n",
        "An embedding table is simply a large matrix for storing embeddings, with two dimensions (B, N), where\n",
        "* B is the number of embeddings stored by the table\n",
        "* N is the number of dimensions per embedding (N-dimensional embedding).\n",
        "\n",
        "\n",
        "The inputs to an embedding table are lookup indices that select which rows (embeddings) to retrieve. In recommendation systems such as those used at Meta, unique IDs are used not only for users but also for entities like posts and ads, serving as lookup indices into their respective embedding tables!\n",
        "\n",
        "Embeddings are trained in RecSys through the following process:\n",
        "1. **Input/lookup indices are fed into the model as unique IDs**. IDs are typically hashed into the range of the embedding table's size, so that an ID larger than the number of rows still maps to a valid row.\n",
        "2. Embeddings are then retrieved and **pooled, such as by taking the sum or mean of the embeddings**. This is required because each example can have a variable number of embeddings, while the rest of the model expects inputs of a consistent shape.\n",
        "3. The **embeddings are used in conjunction with the rest of the model to produce a prediction**, such as [Click-Through Rate (CTR)](https://support.google.com/google-ads/answer/2615875?hl=en) for an Ad.\n",
        "4. The loss is calculated from the prediction and the label for an example, and **all weights of the model are updated through backpropagation and gradient descent, including the embedding weights** that were associated with the example.\n",
        "\n",
        "These embeddings are crucial for representing categorical features, such as users, posts, and ads, in order to capture relationships and make good recommendations. Meta AI's [Deep learning recommendation model](https://arxiv.org/abs/1906.00091) (DLRM) paper talks more about the technical details of using embedding tables in RecSys.\n",
        "\n",
        "This tutorial will introduce the concept of embeddings, showcase TorchRec specific modules/datatypes, and depict how distributed training works with TorchRec."
      ]
    },
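    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the four training steps concrete, here is a minimal pure-PyTorch sketch. It is an illustration under simplifying assumptions, not how TorchRec implements training: a modulo stands in for a real hash function, a single `torch.nn.Linear` layer stands in for the rest of the model, and the table size, IDs, and labels are made up. The last two lines print the rows that were looked up and the rows that the optimizer step actually changed; they should match, because only looked-up rows receive a gradient."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "num_rows, dim = 100, 8\n",
        "\n",
        "# Step 1: raw IDs may exceed the table size, so map them into [0, num_rows).\n",
        "# Modulo is a stand-in for a real hash function (assumption for illustration).\n",
        "raw_ids = torch.tensor([12345, 7, 98765, 42])\n",
        "ids = raw_ids % num_rows\n",
        "# Example 0 owns ids[0:1], example 1 owns ids[1:4]\n",
        "offsets = torch.tensor([0, 1])\n",
        "\n",
        "# Step 2: look up and sum-pool the embeddings of each example -> shape (2, dim)\n",
        "table = torch.nn.EmbeddingBag(num_rows, dim, mode=\"sum\")\n",
        "pooled = table(ids, offsets)\n",
        "\n",
        "# Step 3: a toy stand-in for the rest of the model that outputs a click logit\n",
        "head = torch.nn.Linear(dim, 1)\n",
        "logits = head(pooled).squeeze(-1)\n",
        "\n",
        "# Step 4: loss against the labels, backpropagation, and one SGD step\n",
        "labels = torch.tensor([1.0, 0.0])\n",
        "loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)\n",
        "before = table.weight.detach().clone()\n",
        "optimizer = torch.optim.SGD(list(table.parameters()) + list(head.parameters()), lr=0.1)\n",
        "optimizer.zero_grad()\n",
        "loss.backward()\n",
        "optimizer.step()\n",
        "\n",
        "# Only the rows that were looked up should have changed\n",
        "changed_rows = (table.weight.detach() != before).any(dim=1).nonzero().flatten()\n",
        "print(\"Looked-up rows:\", sorted(set(ids.tolist())))\n",
        "print(\"Updated rows:  \", changed_rows.tolist())"
      ],
      "execution_count": null,
      "outputs": []
    },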
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "48b50971-aeab-4754-8cff-986496689f43",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000131464,
        "executionStopTime": 1726000133971,
        "serverExecutionDuration": 2349.9959111214,
        "requestMsgId": "48b50971-aeab-4754-8cff-986496689f43",
        "customOutput": null,
        "outputsInitialized": true,
        "output": {
          "id": "1534047040582458"
        },
        "id": "AbeT4W9xcso9"
      },
      "source": [
        "import torch"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "4b510f99-840d-4986-b635-33c21af48cf4",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "bjuDdEqocso-"
      },
      "source": [
        "## Embeddings in PyTorch\n",
        "[`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html): Embedding table whose forward pass returns the embeddings themselves, as is.\n",
        "\n",
        "[`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html): Embedding table whose forward pass returns embeddings that are then pooled, e.g. by sum or mean. These are otherwise known as **Pooled Embeddings**.\n",
        "\n",
        "In this section, we give a brief introduction to performing embedding lookups by passing indices into the table. Check out the linked documentation for more sophisticated use cases and experiments!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000133982,
        "executionStopTime": 1726000134201,
        "serverExecutionDuration": 31.60185739398,
        "requestMsgId": "06ebfce4-bc22-4f5a-97d7-7a8f5d8ac375",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "933119035309629"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1X5C_Dnccso-",
        "outputId": "616cc153-67ee-4dd6-b1ab-ee6ff6f44709"
      },
      "source": [
        "num_embeddings, embedding_dim = 10, 4\n",
        "\n",
        "# Initialize our embedding table\n",
        "weights = torch.rand(num_embeddings, embedding_dim)\n",
        "print(\"Weights:\", weights)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Weights: tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
            "        [0.1830, 0.0326, 0.8241, 0.2995],\n",
            "        [0.7328, 0.0531, 0.9528, 0.0592],\n",
            "        [0.7800, 0.1797, 0.0167, 0.7401],\n",
            "        [0.4837, 0.2052, 0.3360, 0.9656],\n",
            "        [0.7887, 0.3066, 0.0956, 0.3344],\n",
            "        [0.5904, 0.8541, 0.5963, 0.2800],\n",
            "        [0.5751, 0.4341, 0.6218, 0.4101],\n",
            "        [0.6881, 0.5363, 0.4747, 0.2301],\n",
            "        [0.6088, 0.1060, 0.1100, 0.7290]])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "b2f21375-8d36-487f-b0c3-ff8a5df950a4",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000134203,
        "executionStopTime": 1726000134366,
        "serverExecutionDuration": 8.956927806139,
        "requestMsgId": "b2f21375-8d36-487f-b0c3-ff8a5df950a4",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "831419729143778"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bxszzeGdcso-",
        "outputId": "88f21deb-c2f1-4894-975e-fdac41436b36"
      },
      "source": [
        "# Pass in pre-generated weights just for this example; typically weights are randomly initialized\n",
        "embedding_collection = torch.nn.Embedding(\n",
        "    num_embeddings, embedding_dim, _weight=weights\n",
        ")\n",
        "embedding_bag_collection = torch.nn.EmbeddingBag(\n",
        "    num_embeddings, embedding_dim, _weight=weights\n",
        ")\n",
        "\n",
        "# Print out the tables; we should see the same weights as above\n",
        "print(\"Embedding Collection Table: \", embedding_collection.weight)\n",
        "print(\"Embedding Bag Collection Table: \", embedding_bag_collection.weight)\n",
        "\n",
        "# Look up rows (embedding IDs) from the embedding tables\n",
        "# 2D tensor with shape (batch_size, num_ids_per_example)\n",
        "ids = torch.tensor([[1, 3]])\n",
        "print(\"Input row IDS: \", ids)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Embedding Collection Table:  Parameter containing:\n",
            "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
            "        [0.1830, 0.0326, 0.8241, 0.2995],\n",
            "        [0.7328, 0.0531, 0.9528, 0.0592],\n",
            "        [0.7800, 0.1797, 0.0167, 0.7401],\n",
            "        [0.4837, 0.2052, 0.3360, 0.9656],\n",
            "        [0.7887, 0.3066, 0.0956, 0.3344],\n",
            "        [0.5904, 0.8541, 0.5963, 0.2800],\n",
            "        [0.5751, 0.4341, 0.6218, 0.4101],\n",
            "        [0.6881, 0.5363, 0.4747, 0.2301],\n",
            "        [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n",
            "Embedding Bag Collection Table:  Parameter containing:\n",
            "tensor([[0.3446, 0.3614, 0.8938, 0.8157],\n",
            "        [0.1830, 0.0326, 0.8241, 0.2995],\n",
            "        [0.7328, 0.0531, 0.9528, 0.0592],\n",
            "        [0.7800, 0.1797, 0.0167, 0.7401],\n",
            "        [0.4837, 0.2052, 0.3360, 0.9656],\n",
            "        [0.7887, 0.3066, 0.0956, 0.3344],\n",
            "        [0.5904, 0.8541, 0.5963, 0.2800],\n",
            "        [0.5751, 0.4341, 0.6218, 0.4101],\n",
            "        [0.6881, 0.5363, 0.4747, 0.2301],\n",
            "        [0.6088, 0.1060, 0.1100, 0.7290]], requires_grad=True)\n",
            "Input row IDS:  tensor([[1, 3]])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "cb5c5906-e9a6-4315-b860-b263e08989be",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000134369,
        "executionStopTime": 1726000134545,
        "serverExecutionDuration": 5.9817284345627,
        "requestMsgId": "cb5c5906-e9a6-4315-b860-b263e08989be",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "2201664893536578"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xkedJeTOcso_",
        "outputId": "46215f2b-03ad-421b-f873-78b2be0df4d4"
      },
      "source": [
        "embeddings = embedding_collection(ids)\n",
        "\n",
        "# Print out the embedding lookups\n",
        "# Each returned embedding should match the corresponding row (ID) of the embedding table above\n",
        "print(\"Embedding Collection Results: \")\n",
        "print(embeddings)\n",
        "print(\"Shape: \", embeddings.shape)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Embedding Collection Results: \n",
            "tensor([[[0.1830, 0.0326, 0.8241, 0.2995],\n",
            "         [0.7800, 0.1797, 0.0167, 0.7401]]], grad_fn=<EmbeddingBackward0>)\n",
            "Shape:  torch.Size([1, 2, 4])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000134547,
        "executionStopTime": 1726000134718,
        "serverExecutionDuration": 7.8675262629986,
        "requestMsgId": "a8e90b32-7c30-41f2-a5b9-bedf2b196e7f",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "6449977515126116"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PmtJkxLccso_",
        "outputId": "c33109b4-205e-492d-e0ca-94d2887ec6e7"
      },
      "source": [
        "# nn.EmbeddingBag's default pooling mode is mean, so the result should be the mean of the looked-up embeddings above (taken over dim=1, the IDs of each example)\n",
        "pooled_embeddings = embedding_bag_collection(ids)\n",
        "\n",
        "print(\"Embedding Bag Collection Results: \")\n",
        "print(pooled_embeddings)\n",
        "print(\"Shape: \", pooled_embeddings.shape)\n",
        "\n",
        "# nn.EmbeddingBag is the same as nn.Embedding but just with pooling (mean, sum, etc.)\n",
        "# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection\n",
        "print(\"Mean: \", torch.mean(embedding_collection(ids), dim=1))"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Embedding Bag Collection Results: \n",
            "tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=<EmbeddingBagBackward0>)\n",
            "Shape:  torch.Size([1, 4])\n",
            "Mean:  tensor([[0.4815, 0.1062, 0.4204, 0.5198]], grad_fn=<MeanBackward1>)\n"
          ]
        }
      ]
    },
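    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`nn.EmbeddingBag` can also take a flat 1D tensor of IDs together with an `offsets` tensor, which lets each example in the batch have a different number of IDs. The sketch below builds a separate table, `demo_bag`, from made-up weights (so the tables above are untouched), switches to `mode=\"sum\"`, and checks the result against manually summed rows. The variable-length layout shown here comes up again with TorchRec's jagged data types below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "demo_weights = torch.rand(10, 4)\n",
        "\n",
        "# Same idea as above, but with sum pooling instead of the default mean\n",
        "demo_bag = torch.nn.EmbeddingBag(10, 4, mode=\"sum\", _weight=demo_weights)\n",
        "\n",
        "# Flat IDs + offsets: example 0 uses IDs [1, 3], example 1 uses IDs [2, 5, 9]\n",
        "flat_ids = torch.tensor([1, 3, 2, 5, 9])\n",
        "demo_offsets = torch.tensor([0, 2])\n",
        "demo_pooled = demo_bag(flat_ids, demo_offsets)\n",
        "\n",
        "# Manually sum the corresponding rows of the weight matrix for comparison\n",
        "expected = torch.stack(\n",
        "    [demo_weights[[1, 3]].sum(dim=0), demo_weights[[2, 5, 9]].sum(dim=0)]\n",
        ")\n",
        "print(\"Shape: \", demo_pooled.shape)\n",
        "print(\"Matches manual sum: \", torch.allclose(demo_pooled, expected))"
      ],
      "execution_count": null,
      "outputs": []
    },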
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "4643305e-2770-40cf-afc6-e64cd3f51063",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "SuCV1cJ8cso_"
      },
      "source": [
        "Congratulations! Now you have a basic understanding of how to use embedding tables, one of the foundations of modern recommendation systems! These tables represent entities and their relationships, for example, the relationship between a given user and the pages and posts they have liked."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "7dfcffeb-c7c0-4d74-9dba-569c1d882898",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "QIuAYSZ5cso_"
      },
      "source": [
        "# TorchRec\n",
        "\n",
        "Embedding tables represent entities such as users, pages, and posts. Because the number of these entities keeps growing, a **hash** function is typically applied to make sure the IDs stay within the bounds of a given embedding table. However, to represent a vast number of entities and reduce hash collisions, these tables can become quite massive (think about the number of ads, for example). In fact, these tables can become so massive that they won't fit on a single GPU, even one with 80 GB of memory!\n",
        "\n",
        "In order to train models with massive embedding tables, sharding these tables across GPUs is required, which then introduces a whole new set of problems/opportunities in parallelism and optimization. Luckily, the TorchRec library has encountered, consolidated, and addressed many of these concerns. TorchRec serves as a **library that provides primitives for large-scale distributed embeddings**.\n",
        "\n",
        "From here on out, we will explore the major features of the TorchRec library. We will start with `torch.nn.Embedding` and extend it to custom TorchRec modules, explore the distributed training environment by generating a sharding plan for embeddings, look at inherent TorchRec optimizations, and extend the model to be ready for inference in C++. Below is a quick outline of what the journey will consist of. Buckle in!\n",
        "\n",
        "1. TorchRec Modules and DataTypes\n",
        "2. Distributed Training, Sharding, and Optimizations\n",
        "3. Inference\n"
      ]
    },
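    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for why a single GPU runs out of room, here is a back-of-the-envelope calculation. The table shape (1 billion rows, 128 dimensions) is a hypothetical example, not a figure from a real model, and the estimate counts only fp32 weights, ignoring optimizer state, activations, and any other tables. With these numbers the table needs 512 GB, i.e. at least 7 GPUs with 80 GB each just to hold the weights."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import math\n",
        "\n",
        "\n",
        "# Size of one embedding table's weights in GB (1 GB = 1e9 bytes); fp32 = 4 bytes\n",
        "def table_size_gb(num_rows: int, embedding_dim: int, bytes_per_element: int = 4) -> float:\n",
        "    return num_rows * embedding_dim * bytes_per_element / 1e9\n",
        "\n",
        "\n",
        "# Hypothetical table: 1 billion IDs (e.g. ads) with 128-dimensional fp32 embeddings\n",
        "size_gb = table_size_gb(num_rows=1_000_000_000, embedding_dim=128)\n",
        "gpu_memory_gb = 80\n",
        "\n",
        "print(f\"Table size: {size_gb:.0f} GB\")\n",
        "print(\"Minimum number of 80 GB GPUs for the weights alone:\", math.ceil(size_gb / gpu_memory_gb))"
      ],
      "execution_count": null,
      "outputs": []
    },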
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "8395ed9c-8336-4686-8e73-cb815b808f2a",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000134724,
        "executionStopTime": 1726000139238,
        "serverExecutionDuration": 4317.9145939648,
        "requestMsgId": "8395ed9c-8336-4686-8e73-cb815b808f2a",
        "outputsInitialized": true,
        "customOutput": null,
        "id": "5vzmNV0IcspA"
      },
      "source": [
        "import torchrec"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "0c95b385-e07a-43e1-aaeb-31f66deb5b35",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "42PwMZnNcspA"
      },
      "source": [
        "## TorchRec Modules and Datatypes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZdSUWBRxoP8R",
        "originalKey": "309c4d38-8f19-46d9-a8bb-7d3d1c166e84",
        "outputsInitialized": false,
        "language": "markdown",
        "showInput": false
      },
      "source": [
        "### From EmbeddingBag to EmbeddingBagCollection\n",
        "We have already explored [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) and [`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).\n",
        "\n",
        "TorchRec extends these modules by creating collections of embeddings, in other words modules that can have multiple embedding tables, with [`EmbeddingCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingCollection) and [`EmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection). We will use `EmbeddingBagCollection` to represent a group of EmbeddingBags.\n",
        "\n",
        "Here, we create an EmbeddingBagCollection (EBC) with two embedding bags, one representing **products** and one representing **users**. Each table, `product_table` and `user_table`, holds 4096 embeddings of dimension 64."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Iz_GZDp_oQ19",
        "originalKey": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e",
        "outputsInitialized": true,
        "language": "python",
        "customOutput": null,
        "executionStartTime": 1726000139247,
        "executionStopTime": 1726000139433,
        "serverExecutionDuration": 13.643965125084,
        "requestMsgId": "219c4ee9-c4f1-43ff-9d1c-b15b16a1dc8e",
        "output": {
          "id": "1615870128957785"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "cec92f4f-d9eb-464e-8e22-26fd910bf8c1"
      },
      "source": [
        "ebc = torchrec.EmbeddingBagCollection(\n",
        "    device=\"cpu\",\n",
        "    tables=[\n",
        "        torchrec.EmbeddingBagConfig(\n",
        "            name=\"product_table\",\n",
        "            embedding_dim=64,\n",
        "            num_embeddings=4096,\n",
        "            feature_names=[\"product\"],\n",
        "            pooling=torchrec.PoolingType.SUM,\n",
        "        ),\n",
        "        torchrec.EmbeddingBagConfig(\n",
        "            name=\"user_table\",\n",
        "            embedding_dim=64,\n",
        "            num_embeddings=4096,\n",
        "            feature_names=[\"user\"],\n",
        "            pooling=torchrec.PoolingType.SUM,\n",
        "        )\n",
        "    ]\n",
        ")\n",
        "print(ebc.embedding_bags)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ModuleDict(\n",
            "  (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
            "  (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
            ")\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "c587a298-4d38-4a69-89a2-5d5c4a26cc2c",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "xjcA0Di1cspA"
      },
      "source": [
        "Let’s inspect the forward method of `EmbeddingBagCollection`, along with the module’s inputs and outputs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "c9d2717b-b753-4e0b-97bd-1596123d081d",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000139437,
        "executionStopTime": 1726000139616,
        "serverExecutionDuration": 6.011176854372,
        "requestMsgId": "c9d2717b-b753-4e0b-97bd-1596123d081d",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "398959426640405"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UuIrEWupcspA",
        "outputId": "300c86c2-82c1-4657-fa1a-6a319eb40177"
      },
      "source": [
        "import inspect\n",
        "\n",
        "# Let's look at the EmbeddingBagCollection forward method\n",
        "# What is a KeyedJaggedTensor and KeyedTensor?\n",
        "print(inspect.getsource(ebc.forward))"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:\n",
            "        \"\"\"\n",
            "        Args:\n",
            "            features (KeyedJaggedTensor): KJT of form [F X B X L].\n",
            "\n",
            "        Returns:\n",
            "            KeyedTensor\n",
            "        \"\"\"\n",
            "        if is_non_strict_exporting() and not torch.jit.is_scripting():\n",
            "            return self._non_strict_exporting_forward(features)\n",
            "        flat_feature_names: List[str] = []\n",
            "        for names in self._feature_names:\n",
            "            flat_feature_names.extend(names)\n",
            "        inverse_indices = reorder_inverse_indices(\n",
            "            inverse_indices=features.inverse_indices_or_none(),\n",
            "            feature_names=flat_feature_names,\n",
            "        )\n",
            "        pooled_embeddings: List[torch.Tensor] = []\n",
            "        feature_dict = features.to_dict()\n",
            "        for i, embedding_bag in enumerate(self.embedding_bags.values()):\n",
            "            for feature_name in self._feature_names[i]:\n",
            "                f = feature_dict[feature_name]\n",
            "                res = embedding_bag(\n",
            "                    input=f.values(),\n",
            "                    offsets=f.offsets(),\n",
            "                    per_sample_weights=f.weights() if self._is_weighted else None,\n",
            "                ).float()\n",
            "                pooled_embeddings.append(res)\n",
            "        return KeyedTensor(\n",
            "            keys=self._embedding_names,\n",
            "            values=process_pooled_embeddings(\n",
            "                pooled_embeddings=pooled_embeddings,\n",
            "                inverse_indices=inverse_indices,\n",
            "            ),\n",
            "            length_per_key=self._lengths_per_embedding,\n",
            "        )\n",
            "\n"
          ]
        }
      ]
    },
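    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Stripped of export handling and inverse-index bookkeeping, the forward method above loops over each table, pools every feature assigned to that table with the table's `EmbeddingBag`, and combines the results into a `KeyedTensor`. The sketch below mimics that loop in plain PyTorch. The names `toy_tables`, `toy_feature_names`, and `toy_features` are hypothetical, the per-feature inputs are hand-written (values, offsets) pairs rather than a real `KeyedJaggedTensor`, and concatenating along dim 1 is an assumption about the layout the `KeyedTensor` wraps, not TorchRec's actual code."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Hypothetical stand-in for an EBC: one EmbeddingBag per table ...\n",
        "toy_tables = torch.nn.ModuleDict(\n",
        "    {\n",
        "        \"product_table\": torch.nn.EmbeddingBag(4096, 64, mode=\"sum\"),\n",
        "        \"user_table\": torch.nn.EmbeddingBag(4096, 64, mode=\"sum\"),\n",
        "    }\n",
        ")\n",
        "# ... plus a mapping from each table to the feature names it serves\n",
        "toy_feature_names = {\"product_table\": [\"product\"], \"user_table\": [\"user\"]}\n",
        "\n",
        "# Per-feature jagged inputs as (values, offsets) for a batch of 2 examples\n",
        "toy_features = {\n",
        "    \"product\": (torch.tensor([1, 2, 3]), torch.tensor([0, 1])),  # [[1], [2, 3]]\n",
        "    \"user\": (torch.tensor([10, 11]), torch.tensor([0, 1])),  # [[10], [11]]\n",
        "}\n",
        "\n",
        "toy_pooled = []\n",
        "for table_name, bag in toy_tables.items():\n",
        "    for feature_name in toy_feature_names[table_name]:\n",
        "        feature_values, feature_offsets = toy_features[feature_name]\n",
        "        toy_pooled.append(bag(input=feature_values, offsets=feature_offsets))\n",
        "\n",
        "# Concatenate per-feature results along the embedding dimension: (batch, 64 + 64)\n",
        "toy_out = torch.cat(toy_pooled, dim=1)\n",
        "print(\"Shape: \", toy_out.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },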
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "d6b9bfc2-544d-499f-ad61-d7471b819f8a",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "C_UAtHsMcspA"
      },
      "source": [
        "### TorchRec Input/Output Data Types\n",
        "TorchRec has distinct data types for input and output of its modules: `JaggedTensor`, `KeyedJaggedTensor`, and `KeyedTensor`. Now you might ask: why create new data types to represent sparse features? To answer that question, we must understand how sparse features are represented in code.\n",
        "\n",
        "Sparse features are otherwise known as `id_list_feature` and `id_score_list_feature`, and are the **IDs** that will be used as indices to an embedding table to retrieve the embedding for that ID. To give a very simple example, imagine a single sparse feature being Ads that a user interacted with. The input itself would be a set of Ad IDs that a user interacted with, and the embeddings retrieved would be a semantic representation of those Ads. The tricky part of representing these features in code is that in each input example, **the number of IDs is variable**. One day a user might interact with only one ad, while the next day they interact with three.\n",
        "\n",
        "A simple representation is shown below, where a `lengths` tensor denotes how many indices are in each example of a batch and a `values` tensor contains the indices themselves.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "13225ead-a798-4db2-8de6-1c13a758d676",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000139620,
        "executionStopTime": 1726000139790,
        "serverExecutionDuration": 3.692839294672,
        "requestMsgId": "13225ead-a798-4db2-8de6-1c13a758d676",
        "outputsInitialized": true,
        "customOutput": null,
        "id": "RB77aL08cspA"
      },
      "source": [
        "# Batch Size 2\n",
        "# 1 ID in example 1, 2 IDs in example 2\n",
        "id_list_feature_lengths = torch.tensor([1, 2])\n",
        "\n",
        "# Values (IDs) tensor: ID 5 is in example 1, IDs 7 and 1 are in example 2\n",
        "id_list_feature_values = torch.tensor([5, 7, 1])"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "65d31fca-7b7f-4c0f-9ca2-56e07243a5c0",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "aKmgGqdNcspA"
      },
      "source": [
        "Let’s look at the offsets, as well as what each example in the batch contains."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "9510cebd-1875-461e-9243-53928632abfa",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000139794,
        "executionStopTime": 1726000139966,
        "serverExecutionDuration": 6.6289491951466,
        "requestMsgId": "9510cebd-1875-461e-9243-53928632abfa",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "869913611744322"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "t5T5S8_mcspB",
        "outputId": "87e78b11-7497-4387-c0bc-b4d277ba8ab3"
      },
      "source": [
        "# Lengths can be converted to offsets for easy indexing of values.\n",
        "# cumsum gives the end position of each example; the first example starts at 0.\n",
        "id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)\n",
        "\n",
        "print(\"Offsets: \", id_list_feature_offsets)\n",
        "print(\"First Example: \", id_list_feature_values[: id_list_feature_offsets[0]])\n",
        "print(\n",
        "    \"Second Example: \",\n",
        "    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],\n",
        ")"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Offsets:  tensor([1, 3])\n",
            "First Example:  tensor([5])\n",
            "Second Example:  tensor([7, 1])\n"
          ]
        }
      ]
    },
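    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above uses end offsets only. A common convention, and the one `JaggedTensor` prints in the next cell, prepends a 0 so that example `i` spans `offsets[i]` to `offsets[i + 1]`. The sketch below builds those offsets by hand and, as an alternative, splits the values directly by their lengths with `torch.split`. The `jagged_` variable names are just for illustration."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "jagged_lengths = torch.tensor([1, 2])\n",
        "jagged_values = torch.tensor([5, 7, 1])\n",
        "\n",
        "# Prepend 0 so example i spans jagged_offsets[i] : jagged_offsets[i + 1]\n",
        "jagged_offsets = torch.cat(\n",
        "    [torch.zeros(1, dtype=jagged_lengths.dtype), torch.cumsum(jagged_lengths, dim=0)]\n",
        ")\n",
        "per_example = [\n",
        "    jagged_values[jagged_offsets[i] : jagged_offsets[i + 1]]\n",
        "    for i in range(len(jagged_lengths))\n",
        "]\n",
        "\n",
        "# torch.split takes the per-example lengths directly and returns the same chunks\n",
        "split_chunks = torch.split(jagged_values, jagged_lengths.tolist())\n",
        "\n",
        "print(\"Offsets: \", jagged_offsets)\n",
        "print(\"Per example (offsets): \", [t.tolist() for t in per_example])\n",
        "print(\"Per example (split): \", [t.tolist() for t in split_chunks])"
      ],
      "execution_count": null,
      "outputs": []
    },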
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000139968,
        "executionStopTime": 1726000140161,
        "serverExecutionDuration": 7.3191449046135,
        "requestMsgId": "4bc3fac5-16b9-4f63-b841-9b26ee0ccfc0",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1254783359215069"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2OOK2BBecspB",
        "outputId": "b27c1547-fbb7-47c4-efb6-48aff3300d1a"
      },
      "source": [
        "from torchrec import JaggedTensor\n",
        "\n",
        "# JaggedTensor is just a wrapper around lengths/offsets and values tensors!\n",
        "jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)\n",
        "\n",
        "# Automatically compute offsets from lengths\n",
        "print(\"Offsets: \", jt.offsets())\n",
        "\n",
        "# Convert to list of values\n",
        "print(\"List of Values: \", jt.to_dense())\n",
        "\n",
        "# __str__ representation\n",
        "print(jt)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Offsets:  tensor([0, 1, 3])\n",
            "List of Values:  [tensor([5]), tensor([7, 1])]\n",
            "JaggedTensor({\n",
            "    [[5], [7, 1]]\n",
            "})\n",
            "\n"
          ]
        }
      ]
    },
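    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `List of Values` printed above is just the values tensor cut into consecutive chunks whose sizes are given by `lengths`. As a rough plain-PyTorch analogy (not necessarily how `JaggedTensor.to_dense()` is implemented), `torch.split` with the lengths as split sizes yields the same per-batch tensors. The tensors below are hypothetical copies of the printed values."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "# Hypothetical copies of the values/lengths used for jt above\n",
        "split_values = torch.tensor([5, 7, 1])\n",
        "split_lengths = torch.tensor([1, 2])\n",
        "\n",
        "# torch.split takes a list of chunk sizes; each chunk holds one batch's IDs\n",
        "per_batch = list(torch.split(split_values, split_lengths.tolist()))\n",
        "print(\"List of Values: \", per_batch)"
      ],
      "execution_count": null,
      "outputs": []
    },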
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "ad069058-2329-4ab9-bee8-60775ead4c33",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000140165,
        "executionStopTime": 1726000140355,
        "serverExecutionDuration": 10.361641645432,
        "requestMsgId": "ad069058-2329-4ab9-bee8-60775ead4c33",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "530006499497328"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fs10Fxu2cspB",
        "outputId": "3bc754c2-ac30-4c01-b0c8-7f98eedb9c52"
      },
      "source": [
        "from torchrec import KeyedJaggedTensor\n",
        "\n",
        "# JaggedTensor represents IDs for 1 feature, but we have multiple features in an EmbeddingBagCollection\n",
        "# That's where KeyedJaggedTensor comes in! KeyedJaggedTensor is just multiple JaggedTensors for multiple features (keys)\n",
        "# From before, we have our two features \"product\" and \"user\". Let's create JaggedTensors for both!\n",
        "\n",
        "product_jt = JaggedTensor(\n",
        "    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])\n",
        ")\n",
        "user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))\n",
        "\n",
        "# Q1: How many batches are there, and which values are in the first batch for product_jt and user_jt?\n",
        "kjt = KeyedJaggedTensor.from_jt_dict({\"product\": product_jt, \"user\": user_jt})\n",
        "\n",
        "# Look at our feature keys for the KeyedJaggedTensor\n",
        "print(\"Keys: \", kjt.keys())\n",
        "\n",
        "# Look at the overall lengths for the KeyedJaggedTensor\n",
        "print(\"Lengths: \", kjt.lengths())\n",
        "\n",
        "# Look at all values for KeyedJaggedTensor\n",
        "print(\"Values: \", kjt.values())\n",
        "\n",
        "# Can convert KJT to dictionary representation\n",
        "print(\"to_dict: \", kjt.to_dict())\n",
        "\n",
        "# KeyedJaggedTensor(KJT) string representation\n",
        "print(kjt)\n",
        "\n",
        "# Q2: What are the offsets for the KeyedJaggedTensor?"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Keys:  ['product', 'user']\n",
            "Lengths:  tensor([3, 1, 2, 2])\n",
            "Values:  tensor([1, 2, 1, 5, 2, 3, 4, 1])\n",
            "to_dict:  {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7953dd7428f0>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7953dd7434f0>}\n",
            "KeyedJaggedTensor({\n",
            "    \"product\": [[1, 2, 1], [5]],\n",
            "    \"user\": [[2, 3], [4, 1]]\n",
            "})\n",
            "\n"
          ]
        }
      ]
    },
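    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `Lengths` and `Values` printed above show that a `KeyedJaggedTensor` stores all of its features in one flat layout: first every batch of `\"product\"`, then every batch of `\"user\"`. The plain-PyTorch sketch below (an illustration, not TorchRec code) recovers the per-feature, per-batch view from those two flat tensors, assuming, as in this example, that every key has the same number of batches."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "# Flat tensors copied from the KeyedJaggedTensor output above\n",
        "kjt_keys = [\"product\", \"user\"]\n",
        "kjt_lengths = torch.tensor([3, 1, 2, 2])\n",
        "kjt_values = torch.tensor([1, 2, 1, 5, 2, 3, 4, 1])\n",
        "\n",
        "# Every key has the same number of batches, so lengths reshape to [num_keys, batch_size]\n",
        "kjt_batch_size = kjt_lengths.numel() // len(kjt_keys)\n",
        "lengths_per_key = kjt_lengths.view(len(kjt_keys), kjt_batch_size)\n",
        "\n",
        "# The total number of IDs per key tells us where to cut the flat values tensor\n",
        "values_per_key = torch.split(kjt_values, lengths_per_key.sum(dim=1).tolist())\n",
        "\n",
        "# Cut each key's values again using that key's per-batch lengths\n",
        "unpacked = {\n",
        "    key: [t.tolist() for t in torch.split(vals, lens.tolist())]\n",
        "    for key, vals, lens in zip(kjt_keys, values_per_key, lengths_per_key)\n",
        "}\n",
        "print(unpacked)"
      ],
      "execution_count": null,
      "outputs": []
    },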
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "b13fdf10-45a7-4e57-b50e-cc18547a715b",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000140357,
        "executionStopTime": 1726000140549,
        "serverExecutionDuration": 17.695877701044,
        "requestMsgId": "b13fdf10-45a7-4e57-b50e-cc18547a715b",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "496557126663787"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JeLwyCNRcspB",
        "outputId": "0723906d-0aba-4d48-e9a5-7ac618d711c5"
      },
      "source": [
        "# Now we can run a forward pass on our ebc from before\n",
        "result = ebc(kjt)\n",
        "result"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<torchrec.sparse.jagged_tensor.KeyedTensor at 0x7953dd7420e0>"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "57a01464-de39-4bfb-8355-83cd97e519c0",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000140552,
        "executionStopTime": 1726000140732,
        "serverExecutionDuration": 6.0368701815605,
        "requestMsgId": "57a01464-de39-4bfb-8355-83cd97e519c0",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1457290878317732"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "R2K4v2vqcspB",
        "outputId": "c507e47e-32bc-4440-8c8a-2b3ea334467c"
      },
      "source": [
        "# Result is a KeyedTensor, which contains a list of the feature names and the embedding results\n",
        "print(result.keys())\n",
        "\n",
        "# The result's shape is [2, 128] because the batch size is 2. Reread the previous section if you need a refresher on how the batch size is determined\n",
        "# 128 is the embedding dimension: when we initialized the EmbeddingBagCollection, we created two tables, \"product\" and \"user\", of dimension 64 each,\n",
        "# so the embeddings for both features are of size 64, and 64 + 64 = 128\n",
        "print(result.values().shape)\n",
        "\n",
        "# The handy to_dict method splits out the embeddings that belong to each feature\n",
        "result_dict = result.to_dict()\n",
        "for key, embedding in result_dict.items():\n",
        "    print(key, embedding.shape)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['product', 'user']\n",
            "torch.Size([2, 128])\n",
            "product torch.Size([2, 64])\n",
            "user torch.Size([2, 64])\n"
          ]
        }
      ]
    },
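    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see where the `[2, 128]` shape comes from, the sketch below repeats the computation with plain `torch.nn.EmbeddingBag` modules. It assumes the table configuration of `ebc` (two tables with 4096 rows, dimension 64, and sum pooling, as its printout later in this notebook shows). The weights here are freshly initialized, so only the shapes and the pooling behavior, not the numbers, correspond to `result`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "# Stand-ins for the two tables in ebc: 4096 rows, dimension 64, sum pooling\n",
        "product_bag = torch.nn.EmbeddingBag(4096, 64, mode=\"sum\")\n",
        "user_bag = torch.nn.EmbeddingBag(4096, 64, mode=\"sum\")\n",
        "\n",
        "with torch.no_grad():\n",
        "    # Same IDs as product_jt / user_jt; offsets mark where each batch starts\n",
        "    product_pooled = product_bag(torch.tensor([1, 2, 1, 5]), torch.tensor([0, 3]))\n",
        "    user_pooled = user_bag(torch.tensor([2, 3, 4, 1]), torch.tensor([0, 2]))\n",
        "    # Sum pooling: batch 0 of \"product\" is the sum of table rows 1, 2 and 1\n",
        "    demo_expected = product_bag.weight[torch.tensor([1, 2, 1])].sum(dim=0)\n",
        "\n",
        "# Each feature pools to [batch_size, 64]; concatenating both features gives [batch_size, 128]\n",
        "demo_pooled = torch.cat([product_pooled, user_pooled], dim=1)\n",
        "print(product_pooled.shape, user_pooled.shape, demo_pooled.shape)\n",
        "print(\"Batch 0 of product is a sum of rows:\", torch.allclose(product_pooled[0], demo_expected))"
      ],
      "execution_count": null,
      "outputs": []
    },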
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "d0fc8635-dac3-444b-978b-421b5d77b70c",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "EE-YYDv7cspB"
      },
      "source": [
        "Congrats! Give yourself a pat on the back for making it this far."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "70816a78-7671-411c-814f-d2c98c3a912c",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "djLHn0CIcspB"
      },
      "source": [
        "## Distributed Training and Sharding\n",
        "Now that we have a grasp on TorchRec modules and data types, it's time to take it to the next level.\n",
        "\n",
        "Remember, TorchRec's main purpose is to provide primitives for distributed embeddings. So far, we've only worked with embedding tables on 1 device. This has been possible because our embedding tables have been small, but in a production setting this generally isn't the case. Embedding tables often get so large that a single table can't fit on one GPU, which creates the need for multiple devices and a distributed environment.\n",
        "\n",
        "In this section, we will set up a distributed environment, mirroring how actual production training is done, and explore sharding embedding tables, all with TorchRec.\n",
        "\n",
        "**This section uses only 1 GPU, though it is treated in a distributed fashion. The single GPU is only a limitation for training, as training runs one process per GPU. Inference does not have this requirement.**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4-v17rxkopQw",
        "originalKey": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4",
        "outputsInitialized": true,
        "language": "python",
        "customOutput": null,
        "executionStartTime": 1726000140740,
        "executionStopTime": 1726000142256,
        "serverExecutionDuration": 1350.0418178737,
        "requestMsgId": "df0d09f0-5e8e-46bf-a086-dd991c8be0b4",
        "output": {
          "id": "1195358511578142"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "efa5689f-b794-41a6-8c06-86856e1e698a"
      },
      "source": [
        "# Here we set up our torch distributed environment\n",
        "# WARNING: You can only call this cell once, calling it again will cause an error\n",
        "# as you can only initialize the process group once\n",
        "\n",
        "import os\n",
        "\n",
        "import torch.distributed as dist\n",
        "\n",
        "# Set up environment variables for distributed training\n",
        "# RANK is the index of this process (one process per GPU); default 0\n",
        "os.environ[\"RANK\"] = \"0\"\n",
        "# How many processes/devices are in our \"world\": this notebook can only run 1 process, so 1 GPU\n",
        "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
        "# Localhost as we are training locally\n",
        "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
        "# Port for distributed training\n",
        "os.environ[\"MASTER_PORT\"] = \"29500\"\n",
        "\n",
        "# Note: you will need a V100 or A100 to run this tutorial as is!\n",
        "# nccl backend is for GPUs, gloo is for CPUs\n",
        "dist.init_process_group(backend=\"gloo\")\n",
        "\n",
        "print(f\"Distributed environment initialized: {dist}\")"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>\n"
          ]
        }
      ]
    },
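    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the default process group can only be initialized once per Python process, a guarded variant is handy if you re-run the notebook. The sketch below assumes the same single-process `gloo` setup as the cell above. `init_single_process_group` is a hypothetical helper name, not a PyTorch or TorchRec API, while `dist.is_initialized()`, `dist.get_rank()` and `dist.get_world_size()` are standard `torch.distributed` functions. If you do need to start over, `dist.destroy_process_group()` tears the default group down."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import os\n",
        "\n",
        "import torch.distributed as dist\n",
        "\n",
        "\n",
        "def init_single_process_group(backend: str = \"gloo\") -> None:\n",
        "    \"\"\"Hypothetical helper: initialize a 1-process group only if none exists yet.\"\"\"\n",
        "    if dist.is_initialized():\n",
        "        print(\"Process group already initialized, skipping\")\n",
        "        return\n",
        "    # Same single-process settings as above; setdefault keeps values that are already set\n",
        "    os.environ.setdefault(\"RANK\", \"0\")\n",
        "    os.environ.setdefault(\"WORLD_SIZE\", \"1\")\n",
        "    os.environ.setdefault(\"MASTER_ADDR\", \"localhost\")\n",
        "    os.environ.setdefault(\"MASTER_PORT\", \"29500\")\n",
        "    dist.init_process_group(backend=backend)\n",
        "\n",
        "\n",
        "init_single_process_group()\n",
        "print(f\"Rank {dist.get_rank()} of world size {dist.get_world_size()}\")"
      ],
      "execution_count": null,
      "outputs": []
    },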
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "480e69dc-3e9d-4e86-b73c-950e18afb0f5",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "hQOjNci3cspB"
      },
      "source": [
        "### Distributed Embeddings\n",
        "\n",
        "We have already worked with the main TorchRec module: `EmbeddingBagCollection`. We have examined how it works along with how data is represented in TorchRec. However, we have not yet explored one of the main parts of TorchRec, which is **distributed embeddings**.\n",
        "\n",
        "GPUs are by far the most popular choice for ML workloads today, as they can perform orders of magnitude more floating point operations per second ([FLOPS](https://en.wikipedia.org/wiki/FLOPS)) than CPUs. However, GPUs come with the limitation of scarce fast memory (HBM, which is analogous to RAM for a CPU), typically on the order of tens of GBs.\n",
        "\n",
        "A RecSys model can contain embedding tables that far exceed the memory limit of 1 GPU, hence the need to distribute the embedding tables across multiple GPUs, otherwise known as **model parallel**. On the other hand, **data parallel** is where the entire model is replicated on each GPU, with each GPU taking in a distinct batch of data for training and syncing gradients on the backward pass.\n",
        "\n",
        "Parts of the model that **require less compute but more memory (embeddings) are distributed with model parallel** while parts that **require more compute and less memory (dense layers, MLP, etc.) are distributed with data parallel**.\n",
        "\n",
        "\n",
        "### Sharding\n",
        "In order to distribute an embedding table, we split up the embedding table into parts and place those parts onto different devices, also known as “sharding”.\n",
        "\n",
        "There are many ways to shard embedding tables. The most common ways are:\n",
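        "* Table-Wise: the table is placed entirely onto one device\n",
        "* Column-Wise: the table is split along the embedding dimension, so each device holds a subset of the columns\n",
        "* Row-Wise: the table is split along its rows, so each device holds a subset of the rows (IDs)\n",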
        "\n",
        "\n",
        "### Sharded Modules\n",
        "While all of this seems like a lot to deal with and implement, you're in luck. **TorchRec provides all the primitives for easy distributed training/inference**! In fact, TorchRec modules have two corresponding classes for working with any TorchRec module in a distributed environment:\n",
        "1. The module sharder: This class exposes a `shard` API that handles sharding a TorchRec Module, producing a sharded module.\n",
        "    *  For `EmbeddingBagCollection`, the sharder is [`EmbeddingBagCollectionSharder`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder)\n",
        "2. Sharded module: This class is a sharded variant of a TorchRec module. It has the same input/output as the regular TorchRec module, but is much more optimized and works in a distributed environment.\n",
        "    * For `EmbeddingBagCollection`, the sharded variant is [`ShardedEmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection)\n",
        "\n",
        "Every TorchRec module has an unsharded and sharded variant.\n",
        "* The unsharded version is meant to be prototyped and experimented with\n",
        "* The sharded version is meant to be used in a distributed environment for distributed training/inference.\n",
        "\n",
        "The sharded versions of TorchRec modules, for example the sharded `EmbeddingBagCollection`, handle everything that is needed for model parallelism, such as communication between GPUs for distributing embeddings to the correct GPUs.\n"
      ]
    },
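    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the three sharding schemes above concrete, the sketch below splits a small toy table across two hypothetical devices using plain tensor slicing. It only models *where* rows and columns end up and how a lookup is reassembled from the pieces. It is not how TorchRec stores shards or moves data between GPUs, and the table sizes, IDs, and device count are made-up values."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "toy_num_devices = 2  # hypothetical number of GPUs\n",
        "toy_table = torch.randn(8, 4)  # toy table: 8 rows (IDs), embedding dimension 4\n",
        "toy_ids = torch.tensor([1, 6])\n",
        "\n",
        "# Table-wise: each whole table lives on one device (hypothetical assignment)\n",
        "table_wise_placement = {\"product_table\": 0, \"user_table\": 1}\n",
        "print(\"Table-wise placement:\", table_wise_placement)\n",
        "\n",
        "# Row-wise: each device holds a contiguous block of rows\n",
        "row_shards = torch.chunk(toy_table, toy_num_devices, dim=0)\n",
        "rows_per_shard = row_shards[0].shape[0]\n",
        "row_lookup = torch.stack(\n",
        "    [row_shards[i // rows_per_shard][i % rows_per_shard] for i in toy_ids.tolist()]\n",
        ")\n",
        "\n",
        "# Column-wise: each device holds a slice of the embedding dimension for every row\n",
        "col_shards = torch.chunk(toy_table, toy_num_devices, dim=1)\n",
        "col_lookup = torch.cat([shard[toy_ids] for shard in col_shards], dim=1)\n",
        "\n",
        "# Both sharded lookups reproduce the unsharded lookup\n",
        "print(\"Row-wise matches:\", torch.equal(row_lookup, toy_table[toy_ids]))\n",
        "print(\"Column-wise matches:\", torch.equal(col_lookup, toy_table[toy_ids]))"
      ],
      "execution_count": null,
      "outputs": []
    },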
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "eb2a064d-0b67-4cba-a199-c99573c7e6cd",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000142261,
        "executionStopTime": 1726000142430,
        "serverExecutionDuration": 8.3460621535778,
        "requestMsgId": "eb2a064d-0b67-4cba-a199-c99573c7e6cd",
        "customOutput": null,
        "outputsInitialized": true,
        "output": {
          "id": "791089056311464"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FX65VcQ6cspB",
        "outputId": "1aa3bc52-569b-46fb-8a94-cd1873e987ca"
      },
      "source": [
        "# Refresher of our EmbeddingBagCollection module\n",
        "ebc"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "EmbeddingBagCollection(\n",
              "  (embedding_bags): ModuleDict(\n",
              "    (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
              "    (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "1442636d-8617-4785-b40c-8544374253b6",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "outputsInitialized": true,
        "executionStartTime": 1726000142433,
        "executionStopTime": 1726000142681,
        "serverExecutionDuration": 4.4135116040707,
        "requestMsgId": "1442636d-8617-4785-b40c-8544374253b6",
        "customOutput": null,
        "output": {
          "id": "502189589096046"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1hSzTg4pcspC",
        "outputId": "d7d86592-4fdc-4d0b-f2ba-a40a80af1fcf"
      },
      "source": [
        "from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n",
        "from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\n",
        "from torchrec.distributed.types import ShardingEnv\n",
        "\n",
        "# Corresponding sharder for EmbeddingBagCollection module\n",
        "sharder = EmbeddingBagCollectionSharder()\n",
        "\n",
        "# ProcessGroup from torch.distributed, initialized earlier with dist.init_process_group\n",
        "pg = dist.GroupMember.WORLD\n",
        "assert pg is not None, \"Process group is not initialized\"\n",
        "\n",
        "print(f\"Process Group: {pg}\")"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7953dd7942b0>\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "29cc17eb-9e2f-480b-aed2-60b15024fbf7",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "qU7A980qcspC"
      },
      "source": [
        "### Planner\n",
        "\n",
        "Before we can show how sharding works, we must know about the **planner**, which helps us determine the best sharding configuration.\n",
        "\n",
        "Given a number of embedding tables and a number of ranks, there are many different sharding configurations that are possible. For example, given 2 embedding tables and 2 GPUs, you can:\n",
        "* Place 1 table on each GPU\n",
        "* Place both tables on a single GPU and no tables on the other\n",
        "* Place certain rows/columns on each GPU\n",
        "\n",
        "Given all of these possibilities, we typically want a sharding configuration that is optimal for performance.\n",
        "\n",
        "That is where the planner comes in. Given the number of embedding tables and the number of GPUs, the planner determines the optimal configuration. It turns out this is incredibly difficult to do manually, with tons of factors that engineers have to consider to ensure an optimal sharding plan. Luckily, TorchRec provides an automatic planner. The TorchRec planner:\n",
        "* assesses memory constraints of hardware,\n",
        "* estimates compute based on memory fetches as embedding lookups,\n",
        "* addresses data-specific factors,\n",
        "* considers other hardware specifics like bandwidth to generate an optimal sharding plan.\n",
        "\n",
        "In order to take all these variables into consideration, the TorchRec planner can take in [various amounts of data for embedding tables, constraints, hardware information, and topology](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155) to aid in generating the optimal sharding plan for a model, information that is routinely provided across stacks.\n",
        "\n",
        "\n",
        "To learn more about sharding, see our [sharding tutorial](https://pytorch.org/tutorials/advanced/sharding.html)."
      ]
    },
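    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a toy illustration of the search problem described above, the sketch below brute-forces every table-wise placement of a few tables onto a few GPUs and keeps the one with the smallest peak per-GPU memory. The cost model (fp32 parameter memory only) and the third `ad_table` are hypothetical; the real TorchRec planner also weighs row-wise and column-wise options, compute, bandwidth, and the other factors listed above. Even this toy version has `num_gpus ** num_tables` candidates, which hints at why planning by hand does not scale."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import itertools\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "# Hypothetical tables: name -> (num_rows, embedding_dim); ad_table is made up\n",
        "toy_tables = {\n",
        "    \"product_table\": (4096, 64),\n",
        "    \"user_table\": (4096, 64),\n",
        "    \"ad_table\": (8192, 64),\n",
        "}\n",
        "toy_num_gpus = 2\n",
        "BYTES_PER_FLOAT = 4  # fp32 weights\n",
        "\n",
        "\n",
        "def table_bytes(shape: Tuple[int, int]) -> int:\n",
        "    rows, dim = shape\n",
        "    return rows * dim * BYTES_PER_FLOAT\n",
        "\n",
        "\n",
        "def best_table_wise_plan(\n",
        "    tables: Dict[str, Tuple[int, int]], num_gpus: int\n",
        ") -> Tuple[Dict[str, int], List[int]]:\n",
        "    names = list(tables)\n",
        "    best_plan, best_load = {}, []\n",
        "    # Try every assignment of tables to GPUs: num_gpus ** num_tables candidates\n",
        "    for ranks in itertools.product(range(num_gpus), repeat=len(names)):\n",
        "        load = [0] * num_gpus\n",
        "        for name, rank in zip(names, ranks):\n",
        "            load[rank] += table_bytes(tables[name])\n",
        "        # Keep the assignment whose most-loaded GPU uses the least memory\n",
        "        if not best_load or max(load) < max(best_load):\n",
        "            best_plan, best_load = dict(zip(names, ranks)), load\n",
        "    return best_plan, best_load\n",
        "\n",
        "\n",
        "toy_plan, toy_load = best_table_wise_plan(toy_tables, toy_num_gpus)\n",
        "print(toy_plan)\n",
        "print([f\"{b / 2**20:.1f} MiB\" for b in toy_load])"
      ],
      "execution_count": null,
      "outputs": []
    },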
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "64936203-2e59-4bc3-8d76-1b652b7891c2",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000142687,
        "executionStopTime": 1726000143033,
        "serverExecutionDuration": 145.92137932777,
        "requestMsgId": "64936203-2e59-4bc3-8d76-1b652b7891c2",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1247084956198777"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PQeXnuAGcspC",
        "outputId": "e6d366b7-f064-4e38-a477-fc2e0d4b9736"
      },
      "source": [
        "# In our case, 1 GPU and compute on CUDA device\n",
        "planner = EmbeddingShardingPlanner(\n",
        "    topology=Topology(\n",
        "        world_size=1,\n",
        "        compute_device=\"cuda\",\n",
        "    )\n",
        ")\n",
        "\n",
        "# Run planner to get plan for sharding\n",
        "plan = planner.collective_plan(ebc, [sharder], pg)\n",
        "\n",
        "print(f\"Sharding Plan generated: {plan}\")"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sharding Plan generated: module: \n",
            "\n",
            "    param     | sharding type | compute kernel | ranks\n",
            "------------- | ------------- | -------------- | -----\n",
            "product_table | table_wise    | fused          | [0]  \n",
            "user_table    | table_wise    | fused          | [0]  \n",
            "\n",
            "    param     | shard offsets | shard sizes |   placement  \n",
            "------------- | ------------- | ----------- | -------------\n",
            "product_table | [0, 0]        | [4096, 64]  | rank:0/cuda:0\n",
            "user_table    | [0, 0]        | [4096, 64]  | rank:0/cuda:0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "bbbbbf60-5691-4357-9943-4d7f8b2b1d5c",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "2TTLj_0PcspC"
      },
      "source": [
        "### Planner Result\n",
        "As you can see above, running the planner produces quite a bit of output: for each table we can see its sharding type, compute kernel, and ranks, along with its shard offsets, shard sizes, and placement.\n",
        "\n",
        "The result of running the planner is a static plan, which can be reused for sharding! This allows sharding to be static for production models instead of determining a new sharding plan every time. Below, we use the sharding plan to finally generate our `ShardedEmbeddingBagCollection`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "533be12d-a3c5-4c9e-9351-7770251c8fa5",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000143037,
        "executionStopTime": 1726000143259,
        "serverExecutionDuration": 5.2368640899658,
        "requestMsgId": "533be12d-a3c5-4c9e-9351-7770251c8fa5",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "901470115170971"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JIci5Gz6cspC",
        "outputId": "32d3cb73-80b6-4646-9c1d-3cff5a498f86"
      },
      "source": [
        "# The static plan that was generated\n",
        "plan"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ShardingPlan(plan={'': {'product_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None), 'user_table': ParameterSharding(sharding_type='table_wise', compute_kernel='fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]), cache_params=None, enforce_hbm=None, stochastic_rounding=None, bounds_check_mode=None, output_dtype=None)}})"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000143262,
        "executionStopTime": 1726000147680,
        "serverExecutionDuration": 4229.5375689864,
        "requestMsgId": "5dcbfda0-0abb-4a51-ba8f-a6a4023f0e2f",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1231077634880712"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2__Do2tqcspC",
        "outputId": "7c19830a-9a11-4fbe-ddf3-8c3cb8a6c3b3"
      },
      "source": [
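        "# ShardingEnv wraps the process group (world size and rank) for the sharder\n",
        "env = ShardingEnv.from_process_group(pg)\n",
        "\n",
        "# Shard the EmbeddingBagCollection module using the EmbeddingBagCollectionSharder\n",
        "# plan.plan is keyed by module path; \"\" refers to the root module, i.e. ebc itself\n",
        "sharded_ebc = sharder.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",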
        "\n",
        "print(f\"Sharded EBC Module: {sharded_ebc}\")"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sharded EBC Module: ShardedEmbeddingBagCollection(\n",
            "  (lookups): \n",
            "   GroupedPooledEmbeddingsLookup(\n",
            "      (_emb_modules): ModuleList(\n",
            "        (0): BatchedFusedEmbeddingBag(\n",
            "          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
            "        )\n",
            "      )\n",
            "    )\n",
            "   (_output_dists): \n",
            "   TwPooledEmbeddingDist()\n",
            "  (embedding_bags): ModuleDict(\n",
            "    (product_table): Module()\n",
            "    (user_table): Module()\n",
            "  )\n",
            ")\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "3ba44a6d-a6f7-4da2-83a6-e8ac974c64ac",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "QBLpkKYIcspC"
      },
      "source": [
        "#### Awaitable\n",
        "Remember that TorchRec is a highly optimized library for distributed embeddings. A concept that TorchRec introduces to enable higher performance for training on GPU is a [`LazyAwaitable`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable). You will see `LazyAwaitable` types as outputs of various sharded TorchRec modules. All a `LazyAwaitable` does is delay calculating some result for as long as possible, which it achieves by acting like an async type."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "e450dc00-bd30-4bc2-8c71-4c01979b0948",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000147687,
        "executionStopTime": 1726000147874,
        "serverExecutionDuration": 9.098757058382,
        "requestMsgId": "e450dc00-bd30-4bc2-8c71-4c01979b0948",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1236006950908310"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rwYzKwyNcspC",
        "outputId": "a8979bb2-6fac-4db6-b997-c31962014e24"
      },
      "source": [
        "from typing import List\n",
        "\n",
        "from torchrec.distributed.types import LazyAwaitable\n",
        "\n",
        "\n",
        "# Demonstrate a LazyAwaitable type\n",
        "class ExampleAwaitable(LazyAwaitable[torch.Tensor]):\n",
        "    def __init__(self, size: List[int]) -> None:\n",
        "        super().__init__()\n",
        "        self._size = size\n",
        "\n",
        "    def _wait_impl(self) -> torch.Tensor:\n",
        "        return torch.ones(self._size)\n",
        "\n",
        "\n",
        "awaitable = ExampleAwaitable([3, 2])\n",
        "awaitable.wait()"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[1., 1.],\n",
              "        [1., 1.],\n",
              "        [1., 1.]])"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
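    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The key property is that nothing is computed until `.wait()` is called. The plain-Python sketch below is a simplified analogy, not TorchRec's `LazyAwaitable` implementation; `DeferredResult` is a hypothetical class name. It makes the deferral visible by counting how many times the stored function runs, and it caches the result after the first `wait()`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "from typing import Callable, Generic, Optional, TypeVar\n",
        "\n",
        "T = TypeVar(\"T\")\n",
        "\n",
        "\n",
        "class DeferredResult(Generic[T]):\n",
        "    \"\"\"Hypothetical stand-in for an awaitable: stores a function and runs it on the first wait().\"\"\"\n",
        "\n",
        "    def __init__(self, compute: Callable[[], T]) -> None:\n",
        "        self._compute = compute\n",
        "        self._result: Optional[T] = None\n",
        "        self._done = False\n",
        "        self.num_computations = 0\n",
        "\n",
        "    def wait(self) -> T:\n",
        "        # Run the deferred computation only once, then reuse the cached result\n",
        "        if not self._done:\n",
        "            self._result = self._compute()\n",
        "            self._done = True\n",
        "            self.num_computations += 1\n",
        "        return self._result\n",
        "\n",
        "\n",
        "deferred = DeferredResult(lambda: sum(range(10)))\n",
        "print(\"Computations before wait():\", deferred.num_computations)\n",
        "print(\"Result:\", deferred.wait())\n",
        "deferred.wait()  # the second wait() reuses the cached result\n",
        "print(\"Computations after two wait() calls:\", deferred.num_computations)"
      ],
      "execution_count": null,
      "outputs": []
    },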
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000147878,
        "executionStopTime": 1726000154861,
        "serverExecutionDuration": 6806.3651248813,
        "requestMsgId": "c958c791-a62c-423a-9a95-1e6ae4e8fbd9",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1255627342282843"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cs41RfzGcspC",
        "outputId": "ba7cfdbe-59c4-48c2-a767-d6f5f5cbb915"
      },
      "source": [
        "kjt = kjt.to(\"cuda\")\n",
        "output = sharded_ebc(kjt)\n",
        "# The output of our sharded EmbeddingBagCollection module is an Awaitable?\n",
        "print(output)"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7953dd7c0430>\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000154865,
        "executionStopTime": 1726000155069,
        "serverExecutionDuration": 6.0432851314545,
        "requestMsgId": "6f2957f2-2e7e-47e4-9237-f0b6c8b0da94",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1057638405967561"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_1sdt75rcspG",
        "outputId": "365cf6cb-adae-4b41-f68d-1a7acf5fe332"
      },
      "source": [
        "kt = output.wait()\n",
        "# Now we have our KeyedTensor after calling .wait()\n",
        "# If you are confused as to why we have a KeyedTensor output,\n",
        "# give yourself a refresher on the unsharded EmbeddingBagCollection module\n",
        "print(type(kt))\n",
        "\n",
        "print(kt.keys())\n",
        "\n",
        "print(kt.values().shape)\n",
        "\n",
        "# Same output format as unsharded EmbeddingBagCollection\n",
        "result_dict = kt.to_dict()\n",
        "for key, embedding in result_dict.items():\n",
        "    print(key, embedding.shape)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>\n",
            "['product', 'user']\n",
            "torch.Size([2, 128])\n",
            "product torch.Size([2, 64])\n",
            "user torch.Size([2, 64])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "4c464de0-20ef-4ef2-89e2-5d58ca224660",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "bEgB987CcspG"
      },
      "source": [
        "### Anatomy of Sharded TorchRec modules\n",
        "\n",
        "We have now sharded an EmbeddingBagCollection using the sharding plan we generated! The sharded module exposes common TorchRec APIs that abstract away distributed communication and computation across multiple GPUs, and these APIs are highly optimized for both training and inference performance. **Below are the three common APIs for distributed training/inference** that TorchRec provides:\n",
        "\n",
        "1. **input_dist**: Handles distributing inputs from GPU to GPU\n",
        "\n",
        "2. **lookups**: Does the actual embedding lookup in an optimized, batched manner using FBGEMM TBE (more on this later)\n",
        "\n",
        "3. **output_dist**: Handles distributing outputs from GPU to GPU\n",
        "\n",
        "The distribution of inputs/outputs is done through [NCCL Collectives](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html), namely [All-to-Alls](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all), in which every GPU sends data to and receives data from every other GPU. TorchRec interfaces with PyTorch distributed for collectives and provides clean abstractions to end users, so they do not have to manage the lower-level details.\n",
        "\n",
        "The backward pass performs these collectives in reverse order to distribute gradients. `input_dist`, `lookups`, and `output_dist` all depend on the sharding scheme. Since we sharded table-wise, these APIs are modules constructed by [TwPooledEmbeddingSharding](https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding).\n"
      ]
    },
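    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the three stages concrete, here is a minimal pure-Python sketch of one forward pass under table-wise sharding with two ranks. It is a single-process simulation written for illustration, not TorchRec's implementation: the functions `input_dist`, `lookups`, and `output_dist` are hypothetical stand-ins for the modules above, the tables and inputs are made up, lists stand in for tensors, and each all-to-all is modeled as a regrouping of Python dictionaries.\n",
        "\n",
        "```python\n",
        "# Hypothetical single-process simulation of table-wise (TW) sharding on 2 ranks.\n",
        "# Each table lives entirely on one rank; lists stand in for tensors.\n",
        "\n",
        "WORLD_SIZE = 2\n",
        "PLACEMENT = {'product': 0, 'user': 1}  # table -> owning rank (assumed)\n",
        "\n",
        "# Tiny made-up embedding tables: row id -> embedding vector (dim 2)\n",
        "TABLES = {\n",
        "    'product': {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]},\n",
        "    'user': {0: [2.0, 2.0], 1: [3.0, 0.0]},\n",
        "}\n",
        "\n",
        "# Local batches: rank -> {feature: list of id bags, one bag per sample}\n",
        "LOCAL_INPUTS = {\n",
        "    0: {'product': [[0, 1]], 'user': [[0]]},\n",
        "    1: {'product': [[2]], 'user': [[1, 0]]},\n",
        "}\n",
        "\n",
        "\n",
        "def input_dist(local_inputs):\n",
        "    # All-to-all: send each feature's ids to the rank that owns its table.\n",
        "    received = {r: {} for r in range(WORLD_SIZE)}\n",
        "    for src in range(WORLD_SIZE):\n",
        "        for feature, bags in local_inputs[src].items():\n",
        "            dst = PLACEMENT[feature]\n",
        "            received[dst].setdefault(feature, {})[src] = bags\n",
        "    return received\n",
        "\n",
        "\n",
        "def lookups(received):\n",
        "    # Each rank sum-pools the rows of the tables it owns, per source rank.\n",
        "    pooled = {r: {} for r in range(WORLD_SIZE)}\n",
        "    for rank, features in received.items():\n",
        "        for feature, per_src in features.items():\n",
        "            table = TABLES[feature]\n",
        "            for src, bags in per_src.items():\n",
        "                pooled[rank].setdefault(feature, {})[src] = [\n",
        "                    [sum(col) for col in zip(*(table[i] for i in bag))]\n",
        "                    for bag in bags\n",
        "                ]\n",
        "    return pooled\n",
        "\n",
        "\n",
        "def output_dist(pooled):\n",
        "    # All-to-all: return each pooled embedding to the rank that sent the ids.\n",
        "    outputs = {r: {} for r in range(WORLD_SIZE)}\n",
        "    for rank, features in pooled.items():\n",
        "        for feature, per_src in features.items():\n",
        "            for src, embs in per_src.items():\n",
        "                outputs[src][feature] = embs\n",
        "    return outputs\n",
        "\n",
        "\n",
        "result = output_dist(lookups(input_dist(LOCAL_INPUTS)))\n",
        "print(result[0])  # {'product': [[1.0, 1.0]], 'user': [[2.0, 2.0]]}\n",
        "```\n",
        "\n",
        "Each rank gets back pooled embeddings for exactly its own samples, one entry per feature, which mirrors why the sharded module produces the same output format as the unsharded `EmbeddingBagCollection`."
      ]
    },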
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "03e6e163-af3a-4443-a5a8-3f877fc401d2",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000155075,
        "executionStopTime": 1726000155253,
        "serverExecutionDuration": 5.8192722499371,
        "requestMsgId": "03e6e163-af3a-4443-a5a8-3f877fc401d2",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1042737524520351"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "O2ptES89cspG",
        "outputId": "2b801648-6501-4463-d743-4887da340974"
      },
      "source": [
        "sharded_ebc"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ShardedEmbeddingBagCollection(\n",
              "  (lookups): \n",
              "   GroupedPooledEmbeddingsLookup(\n",
              "      (_emb_modules): ModuleList(\n",
              "        (0): BatchedFusedEmbeddingBag(\n",
              "          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "   (_input_dists): \n",
              "   TwSparseFeaturesDist(\n",
              "      (_dist): KJTAllToAll()\n",
              "    )\n",
              "   (_output_dists): \n",
              "   TwPooledEmbeddingDist(\n",
              "      (_dist): PooledEmbeddingsAllToAll()\n",
              "    )\n",
              "  (embedding_bags): ModuleDict(\n",
              "    (product_table): Module()\n",
              "    (user_table): Module()\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000155256,
        "executionStopTime": 1726000155442,
        "serverExecutionDuration": 5.3565315902233,
        "requestMsgId": "c2a34340-d5fd-4dc8-9b7e-3a761a0c5f82",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1063399165221115"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PHjJt3BQcspG",
        "outputId": "aaedadd8-da43-4225-d5f8-a7a43fd0250a"
      },
      "source": [
        "# Distribute input KJTs to all other GPUs and receive KJTs\n",
        "sharded_ebc._input_dists"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[TwSparseFeaturesDist(\n",
              "   (_dist): KJTAllToAll()\n",
              " )]"
            ]
          },
          "metadata": {},
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "88abe892-1ed1-4806-84ad-35f43247a772",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000155445,
        "executionStopTime": 1726000155695,
        "serverExecutionDuration": 5.3521953523159,
        "requestMsgId": "88abe892-1ed1-4806-84ad-35f43247a772",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1513800839249249"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jrEXMc7TcspG",
        "outputId": "81ab40a5-135b-494c-f2bc-91be16a338cc"
      },
      "source": [
        "# Distribute output embeddings to all other GPUs and receive embeddings\n",
        "sharded_ebc._output_dists"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[TwPooledEmbeddingDist(\n",
              "   (_dist): PooledEmbeddingsAllToAll()\n",
              " )]"
            ]
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "2eaf16f1-ac14-4f7a-b443-e707ff85c3f0",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "C2jfo5ilcspH"
      },
      "source": [
        "### Optimizing Embedding Lookups\n",
        "\n",
        "In performing lookups for a collection of embedding tables, a trivial solution would be to iterate through all the `nn.EmbeddingBag` modules and do one lookup per table. This is exactly what the standard, unsharded TorchRec `EmbeddingBagCollection` does. However, while this solution is simple, it is slow, because every table requires its own separate lookup.\n",
        "\n",
        "[FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu) is a library that provides highly optimized GPU operators (otherwise known as kernels). One of these operators, **Table Batched Embedding** (TBE), provides two major optimizations:\n",
        "\n",
        "* Table batching, which allows you to look up multiple embeddings with one kernel call.\n",
        "* Optimizer fusion, which allows the module to update its own weights during the backward pass, given the canonical PyTorch optimizers and arguments.\n",
        "\n",
        "The `ShardedEmbeddingBagCollection` uses the FBGEMM TBE for lookups instead of traditional `nn.EmbeddingBag` modules, for optimized embedding lookups."
      ]
    },
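    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind table batching can be sketched in plain Python: store all tables in one contiguous buffer, record the row offset at which each table starts, and serve pooled lookups for every table from a single call. This is only an illustration under simplifying assumptions (lists instead of GPU tensors, sum pooling, a made-up `batched_lookup` helper); it is not FBGEMM's kernel or its memory layout. On a GPU, the benefit comes from processing all of these bags in one kernel launch instead of one launch per table.\n",
        "\n",
        "```python\n",
        "# Hypothetical illustration of table batching: several embedding tables are\n",
        "# stored in one contiguous buffer, and a single call pools lookups for all of\n",
        "# them. Pure-Python sketch, not FBGEMM's kernel or memory layout.\n",
        "\n",
        "DIM = 2\n",
        "tables = {\n",
        "    'product_table': [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],\n",
        "    'user_table': [[2.0, 2.0], [3.0, 0.0]],\n",
        "}\n",
        "\n",
        "# Concatenate all rows into one buffer and remember where each table starts.\n",
        "weights = []\n",
        "table_offsets = {}\n",
        "for name, rows in tables.items():\n",
        "    table_offsets[name] = len(weights)\n",
        "    weights.extend(rows)\n",
        "\n",
        "\n",
        "def batched_lookup(requests):\n",
        "    # requests: list of (table_name, bag_of_row_ids); one call serves all tables.\n",
        "    pooled = []\n",
        "    for name, bag in requests:\n",
        "        base = table_offsets[name]\n",
        "        acc = [0.0] * DIM\n",
        "        for row_id in bag:\n",
        "            row = weights[base + row_id]  # global row = table offset + local id\n",
        "            acc = [a + w for a, w in zip(acc, row)]\n",
        "        pooled.append(acc)\n",
        "    return pooled\n",
        "\n",
        "\n",
        "out = batched_lookup([('product_table', [0, 2]), ('user_table', [1])])\n",
        "print(out)  # [[2.0, 1.0], [3.0, 0.0]]\n",
        "```"
      ]
    },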
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000155699,
        "executionStopTime": 1726000155879,
        "serverExecutionDuration": 5.0756596028805,
        "requestMsgId": "801c50b9-e1a2-465a-9fa3-3cd87d676ed4",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "911093750838903"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1GoHWI6OcspH",
        "outputId": "cd67815b-00bd-4a30-89cf-7b5d9c7051e9"
      },
      "source": [
        "sharded_ebc._lookups"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[GroupedPooledEmbeddingsLookup(\n",
              "   (_emb_modules): ModuleList(\n",
              "     (0): BatchedFusedEmbeddingBag(\n",
              "       (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
              "     )\n",
              "   )\n",
              " )]"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "f2b31d78-81a9-426f-b017-ca8404383939",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "1zcbZX1lcspH"
      },
      "source": [
        "### DistributedModelParallel\n",
        "\n",
        "We have now explored sharding a single EmbeddingBagCollection! We took the `EmbeddingBagCollectionSharder` and the unsharded `EmbeddingBagCollection` and generated a `ShardedEmbeddingBagCollection` module. This workflow works, but for model parallelism, [`DistributedModelParallel`](https://pytorch.org/torchrec/model-parallel-api-reference.html#model-parallel) (DMP) is typically used as the standard interface. When you wrap your model (in our case `ebc`) with DMP, the following occurs:\n",
        "\n",
        "1. Decide how to shard the model. DMP collects the available ‘sharders’ and comes up with a ‘plan’ for the optimal way to shard the embedding table(s) (i.e., the EmbeddingBagCollection).\n",
        "2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).\n",
        "\n",
        "DMP accepts everything we've just experimented with, such as a static sharding plan and a list of sharders, but it also has sensible defaults for sharding a TorchRec model seamlessly. In this toy example, since we have two embedding tables and one GPU, TorchRec places both on the single GPU.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000155883,
        "executionStopTime": 1726000156073,
        "serverExecutionDuration": 7.8761726617813,
        "requestMsgId": "e0e198e1-db2a-46b0-91f0-51a5ff80abbb",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1207953610328397"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ypVUDwpzcspH",
        "outputId": "26a7a957-c231-459a-dfc8-f0c1cd6f697e"
      },
      "source": [
        "ebc"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "EmbeddingBagCollection(\n",
              "  (embedding_bags): ModuleDict(\n",
              "    (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
              "    (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "73fec38d-947a-49d5-a2ba-61e3828b7117",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000156075,
        "executionStopTime": 1726000156438,
        "serverExecutionDuration": 165.43522849679,
        "requestMsgId": "73fec38d-947a-49d5-a2ba-61e3828b7117",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1838328716783594"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5EdlyWAycspH",
        "outputId": "05e90aa2-cb83-4ddf-9da5-6aa31d6da278"
      },
      "source": [
        "model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))"
      ],
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "f8d87a4e-6a7a-4a02-92f9-9baa794266af",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000156441,
        "executionStopTime": 1726000156665,
        "serverExecutionDuration": 6.8417005240917,
        "requestMsgId": "f8d87a4e-6a7a-4a02-92f9-9baa794266af",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1059040285804352"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "b5NgRErjcspH",
        "outputId": "8f4de40c-3a3b-43e5-d645-814bf03dab0b"
      },
      "source": [
        "out = model(kjt)\n",
        "out.wait()"
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<torchrec.sparse.jagged_tensor.KeyedTensor at 0x7953dd7dcd90>"
            ]
          },
          "metadata": {},
          "execution_count": 31
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "e7e02648-dee7-4b3a-8953-47e8b8771c3b",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000156669,
        "executionStopTime": 1726000156885,
        "serverExecutionDuration": 5.4804161190987,
        "requestMsgId": "e7e02648-dee7-4b3a-8953-47e8b8771c3b",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "3346626825643095"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VJrSysjjcspH",
        "outputId": "2920a3d0-dd96-43ea-ab0d-627de00d1e42"
      },
      "source": [
        "model"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DistributedModelParallel(\n",
              "  (_dmp_wrapped_module): ShardedEmbeddingBagCollection(\n",
              "    (lookups): \n",
              "     GroupedPooledEmbeddingsLookup(\n",
              "        (_emb_modules): ModuleList(\n",
              "          (0): BatchedFusedEmbeddingBag(\n",
              "            (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "     (_input_dists): \n",
              "     TwSparseFeaturesDist(\n",
              "        (_dist): KJTAllToAll()\n",
              "      )\n",
              "     (_output_dists): \n",
              "     TwPooledEmbeddingDist(\n",
              "        (_dist): PooledEmbeddingsAllToAll()\n",
              "      )\n",
              "    (embedding_bags): ModuleDict(\n",
              "      (product_table): Module()\n",
              "      (user_table): Module()\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "4b6171d5-ae60-4cc8-a47a-f01236c02e6c",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "BLM673eTcspH"
      },
      "source": [
        "### Sharding Best Practices\n",
        "\n",
        "Currently, our configuration shards on only 1 GPU (or rank), which is trivial: all the tables are placed in that one GPU's memory. However, in real production use cases, embedding tables are **typically sharded across hundreds of GPUs**, using different sharding methods such as table-wise, row-wise, and column-wise. It is important to choose a sharding configuration that prevents out-of-memory issues while keeping the load balanced, not only in terms of memory but also in terms of compute, for optimal performance."
      ]
    },
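    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick memory calculation shows why placement matters. The sketch below computes the FP32 size of a table and then places whole tables (table-wise) on ranks with a simple greedy rule: largest table first, onto the currently least-loaded rank. This is a hypothetical heuristic for illustration only; TorchRec's planner also accounts for compute and communication costs and supports row-wise and column-wise splits, none of which are modeled here. The `hypothetical_big_table` entry is made up.\n",
        "\n",
        "```python\n",
        "# Hypothetical greedy table-wise placement that balances memory across ranks.\n",
        "# Not TorchRec's planner: compute and communication costs are ignored.\n",
        "\n",
        "BYTES_PER_FP32 = 4\n",
        "\n",
        "\n",
        "def table_bytes(num_embeddings, embedding_dim, bytes_per_elem=BYTES_PER_FP32):\n",
        "    return num_embeddings * embedding_dim * bytes_per_elem\n",
        "\n",
        "\n",
        "def greedy_place(tables, world_size):\n",
        "    # tables: {name: (num_embeddings, embedding_dim)}\n",
        "    load = [0] * world_size\n",
        "    plan = {}\n",
        "    # Place the largest tables first, each on the currently least-loaded rank.\n",
        "    for name, shape in sorted(tables.items(), key=lambda kv: -table_bytes(*kv[1])):\n",
        "        rank = min(range(world_size), key=lambda r: load[r])\n",
        "        plan[name] = rank\n",
        "        load[rank] += table_bytes(*shape)\n",
        "    return plan, load\n",
        "\n",
        "\n",
        "# The toy tables in this tutorial are 4096 x 64 FP32, i.e. 1 MiB each.\n",
        "print(table_bytes(4096, 64))  # 1048576\n",
        "\n",
        "tables = {\n",
        "    'product_table': (4096, 64),\n",
        "    'user_table': (4096, 64),\n",
        "    'hypothetical_big_table': (100_000, 128),\n",
        "}\n",
        "plan, load = greedy_place(tables, world_size=2)\n",
        "print(plan)  # {'hypothetical_big_table': 0, 'product_table': 1, 'user_table': 1}\n",
        "print(load)  # [51200000, 2097152]\n",
        "```\n",
        "\n",
        "Even this toy case ends up imbalanced, because one large table cannot be split table-wise; that is where row-wise and column-wise sharding help."
      ]
    },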
    {
      "cell_type": "markdown",
      "source": [
        "### Adding in the Optimizer\n",
        "\n",
        "Remember that TorchRec modules are hyperoptimized for large-scale distributed training. An important optimization concerns the optimizer. **TorchRec modules provide a seamless API to fuse the backward pass and the optimizer step in training. This significantly improves performance and decreases memory usage, and it also lets you assign distinct optimizers to distinct model parameters.**\n",
        "\n",
        "#### Optimizer Classes\n",
        "\n",
        "TorchRec uses `CombinedOptimizer`, which contains a collection of `KeyedOptimizers`. A `CombinedOptimizer` makes it easy to handle multiple optimizers for various subgroups in the model. A `KeyedOptimizer` extends `torch.optim.Optimizer` and is initialized from a dictionary of named parameters, which it exposes. Each `TBE` module in an `EmbeddingBagCollection` has its own `KeyedOptimizer`, and these are combined into one `CombinedOptimizer`.\n",
        "\n",
        "#### Fused optimizer in TorchRec\n",
        "\n",
        "Using `DistributedModelParallel`, the **optimizer is fused, which means that the optimizer update is done in the backward pass**. This is an optimization in TorchRec and FBGEMM where the embedding gradients are not materialized; instead, the update is applied directly to the parameters. This brings significant memory savings, as embedding gradients are typically the size of the parameters themselves.\n",
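        "\n",
        "The following sketch shows what a fused update looks like for a single embedding table. It assumes the exact row-wise Adagrad formulation (one accumulator per row that adds up the mean squared gradient of that row) and reuses the `learning_rate` and `eps` values from the fused parameters example below as defaults; the function `fused_rowwise_adagrad_backward` is hypothetical and is not FBGEMM's kernel. The point is that each looked-up row's gradient is consumed immediately during the backward pass, so no full-size gradient tensor is ever stored.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "# Hypothetical sketch of a fused, row-wise Adagrad update applied during the\n",
        "# backward pass. Not FBGEMM's kernel; only rows used in the batch are touched.\n",
        "\n",
        "\n",
        "def fused_rowwise_adagrad_backward(weights, state, row_grads, lr=0.02, eps=0.002):\n",
        "    # weights: list of rows (lists of floats), updated in place\n",
        "    # state: one float accumulator per row (row-wise optimizer state)\n",
        "    # row_grads: {row_id: gradient for that row}, only rows used in the batch\n",
        "    for row_id, grad in row_grads.items():\n",
        "        # Accumulate the mean squared gradient of this row.\n",
        "        state[row_id] += sum(g * g for g in grad) / len(grad)\n",
        "        step = lr / (math.sqrt(state[row_id]) + eps)\n",
        "        weights[row_id] = [w - step * g for w, g in zip(weights[row_id], grad)]\n",
        "        # The row gradient is discarded here; nothing is kept for optimizer.step().\n",
        "\n",
        "\n",
        "weights = [[1.0, 1.0], [0.5, -0.5], [0.0, 0.0]]\n",
        "state = [0.0, 0.0, 0.0]\n",
        "fused_rowwise_adagrad_backward(weights, state, {1: [1.0, 1.0]})\n",
        "print(weights[0], weights[2])  # untouched rows: [1.0, 1.0] [0.0, 0.0]\n",
        "print(state)  # [0.0, 1.0, 0.0]\n",
        "```\n",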
        "\n",
        "You can, however, choose to make the optimizer `dense`, which does not apply this optimization and lets you inspect the embedding gradients or apply computations to them as you wish. A dense optimizer in this case would be your [canonical PyTorch model training loop with optimizer.](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html)\n",
        "\n",
        "Once the optimizer is created through `DistributedModelParallel`, you still need to manage an optimizer for the other parameters not associated with TorchRec embedding modules. To find the other parameters, use `in_backward_optimizer_filter(model.named_parameters())`.\n",
        "\n",
        "Apply an optimizer to those parameters as you would with a normal PyTorch optimizer, and combine it and the `model.fused_optimizer` into one `CombinedOptimizer` that you can use in your training loop to `zero_grad` and `step` through.\n",
        "\n",
        "#### Let's add an optimizer to our EmbeddingBagCollection\n",
        "We will do this in two equivalent ways, so you can choose whichever you prefer:\n",
        "1. Passing optimizer kwargs through fused parameters (`fused_params`) in the sharder\n",
        "2. Through `apply_optimizer_in_backward`\n",
        "\n",
        "Note: `apply_optimizer_in_backward` converts the optimizer parameters to `fused_params` to pass to the `TBE` in the `EmbeddingBagCollection`/`EmbeddingCollection`."
      ],
      "metadata": {
        "id": "zFhggkUCmd7f"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Approach 1: passing optimizer kwargs through fused parameters\n",
        "from torchrec.optim.optimizers import in_backward_optimizer_filter\n",
        "from fbgemm_gpu.split_embedding_configs import EmbOptimType\n",
        "\n",
        "\n",
        "# Fused parameters for the FBGEMM TBE: the optimizer type and its hyperparameters\n",
        "fused_params = {\n",
        "    \"optimizer\": EmbOptimType.EXACT_ROWWISE_ADAGRAD,\n",
        "    \"learning_rate\": 0.02,\n",
        "    \"eps\": 0.002,\n",
        "}\n",
        "\n",
        "# Init sharder with fused_params\n",
        "sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)\n",
        "\n",
        "# We'll use same plan and unsharded EBC as before but this time with our new sharder\n",
        "sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
        "\n",
        "# Comparing the fused optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.\n",
        "# If TBE logs are printed for this cell, they also show that the new optimizer is being applied\n",
        "print(f\"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}\")\n",
        "print(f\"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}\")\n",
        "\n",
        "print(f\"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h5BCEFidmnEw",
        "outputId": "202c64f7-ae95-4b0d-9f53-16138a680d7d"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (\n",
            "Parameter Group 0\n",
            "    lr: 0.01\n",
            ")\n",
            "Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (\n",
            "Parameter Group 0\n",
            "    lr: 0.02\n",
            ")\n",
            "Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n",
        "import copy\n",
        "# Approach 2: applying optimizer through apply_optimizer_in_backward\n",
        "# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it\n",
        "\n",
        "# This achieves the same result as the previous approach, starting from a copy of the unsharded EBC\n",
        "ebc_apply_opt = copy.deepcopy(ebc)\n",
        "optimizer_kwargs = {\"lr\": 0.5}\n",
        "\n",
        "for name, param in ebc_apply_opt.named_parameters():\n",
        "    print(f\"{name=}\")\n",
        "    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)\n",
        "\n",
        "sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[\"\"], env, torch.device(\"cuda\"))\n",
        "\n",
        "# Printing the fused optimizer now shows our new learning rate; the TBE logs, if printed, also confirm it\n",
        "print(sharded_ebc_apply_opt.fused_optimizer)\n",
        "print(type(sharded_ebc_apply_opt.fused_optimizer))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T-xx724MmoKv",
        "outputId": "0f58fb18-f423-4c84-ee57-d37bdba28eb8"
      },
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "name='embedding_bags.product_table.weight'\n",
            "name='embedding_bags.user_table.weight'\n",
            ": EmbeddingFusedOptimizer (\n",
            "Parameter Group 0\n",
            "    lr: 0.5\n",
            ")\n",
            "<class 'torchrec.optim.keyed.CombinedOptimizer'>\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-34-92db06a959c3>:1: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.\n",
            "  from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# We can also use the filter to find parameters that aren't associated with the fused optimizer(s),\n",
        "# i.e., parameters of non-TorchRec modules. Since our module is just a TorchRec EBC,\n",
        "# there are no such parameters\n",
        "print(\"Non Fused Model Parameters:\")\n",
        "print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UEyhlmbmlwsW",
        "outputId": "f6219673-d14d-444e-a451-98f33ddeb54d"
      },
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Non Fused Model Parameters:\n",
            "dict_keys([])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Here we do a dummy backward call and see that parameter updates for fused\n",
        "# optimizers happen as a result of the backward pass\n",
        "\n",
        "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n",
        "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n",
        "print(f\"First Iteration Loss: {loss}\")\n",
        "\n",
        "loss.backward()\n",
        "\n",
        "ebc_output = sharded_ebc_fused_params(kjt).wait().values()\n",
        "loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n",
        "# We never call optimizer.step(), so for the loss to have changed here,\n",
        "# the parameters must have been updated during the backward pass, which is\n",
        "# what the fused optimizer automatically handles for us\n",
        "print(f\"Second Iteration Loss: {loss}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Bga-zM2OfnMW",
        "outputId": "6c5d45f2-c479-4932-b39d-5ff8abe27d3c"
      },
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "First Iteration Loss: 255.94378662109375\n",
            "Second Iteration Loss: 245.72166442871094\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "e3bdc895-54c4-4fc6-9175-28dd75021c6a",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "Xc-RUDwDcspH"
      },
      "source": [
        "## Inference\n",
        "\n",
        "Now that we are able to train distributed embeddings, how can we take the trained model and optimize it for inference? Inference is typically very sensitive to **performance and size of the model**. Running just the trained model in a Python environment is inefficient. There are two key differences between inference and training environments:\n",
        "* **Quantization**: Inference models are typically quantized, where model parameters lose precision in exchange for lower prediction latency and reduced model size. For example, each embedding weight may go from FP32 (4 bytes) in the trained model to INT8 (1 byte). This is also necessary given the vast scale of embedding tables, as we want to use as few devices as possible for inference to minimize latency.\n",
        "* **C++ environment**: Inference latency is a big deal, so to ensure adequate performance, the model is typically run in a C++ environment (as well as in situations where there is no Python runtime, such as on device).\n",
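        "\n",
        "As a rough illustration of the quantization trade-off, the sketch below quantizes one FP32 embedding row to 8-bit integers using a per-row scale and offset (min/max affine quantization), then estimates the size of a 4096 x 64 table before and after. The per-row scheme and the FP32 storage of scale and offset are assumptions chosen for illustration and are not necessarily the exact format used by TorchRec or FBGEMM; `quantize_row` and `dequantize_row` are hypothetical helpers.\n",
        "\n",
        "```python\n",
        "# Hypothetical per-row affine 8-bit quantization of an embedding row.\n",
        "# Each row keeps its own scale and offset; values map to the levels 0..255.\n",
        "\n",
        "\n",
        "def quantize_row(row):\n",
        "    lo, hi = min(row), max(row)\n",
        "    scale = (hi - lo) / 255 if hi > lo else 1.0\n",
        "    q = [round((x - lo) / scale) for x in row]  # unsigned 8-bit levels\n",
        "    return q, scale, lo\n",
        "\n",
        "\n",
        "def dequantize_row(q, scale, offset):\n",
        "    return [v * scale + offset for v in q]\n",
        "\n",
        "\n",
        "row = [-1.0, -0.3, 0.1, 0.7, 1.0]\n",
        "q, scale, offset = quantize_row(row)\n",
        "restored = dequantize_row(q, scale, offset)\n",
        "max_err = max(abs(a - b) for a, b in zip(row, restored))\n",
        "print(q)\n",
        "print(max_err <= scale / 2)  # True: error is at most half a quantization step\n",
        "\n",
        "# Size of one 4096 x 64 table: FP32 vs. INT8 plus an assumed FP32 scale and\n",
        "# offset (2 x 4 bytes) stored per row.\n",
        "fp32_bytes = 4096 * 64 * 4\n",
        "int8_bytes = 4096 * (64 * 1 + 2 * 4)\n",
        "print(fp32_bytes, int8_bytes)  # 1048576 294912\n",
        "```\n",
        "\n",
        "Under these assumptions, the table shrinks to roughly 28% of its FP32 size, while each restored weight stays within half a quantization step of the original.\n",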
        "\n",
        "TorchRec provides primitives for making a TorchRec model inference-ready, including:\n",
        "* APIs for quantizing the model, which automatically introduce optimizations with FBGEMM TBE\n",
        "* sharding embeddings for distributed inference\n",
        "* compiling the model to [TorchScript](https://pytorch.org/docs/stable/jit.html) (which can be run from C++)\n",
        "\n",
        "In this section, we will go over this entire workflow of:\n",
        "* Quantizing the model\n",
        "* Sharding the quantized model\n",
        "* Compiling the sharded quantized model into TorchScript"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "aae8ef10-f7a4-421a-b71c-f177ff74e96a",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "outputsInitialized": true,
        "executionStartTime": 1726000156892,
        "executionStopTime": 1726000157069,
        "serverExecutionDuration": 7.4504055082798,
        "requestMsgId": "aae8ef10-f7a4-421a-b71c-f177ff74e96a",
        "customOutput": null,
        "output": {
          "id": "456742254014129"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8JypsUNmcspH",
        "outputId": "0a745234-d316-4850-d84a-f08b0f045595"
      },
      "source": [
        "ebc"
      ],
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "EmbeddingBagCollection(\n",
              "  (embedding_bags): ModuleDict(\n",
              "    (product_table): EmbeddingBag(4096, 64, mode='sum')\n",
              "    (user_table): EmbeddingBag(4096, 64, mode='sum')\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 37
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "30694976-da54-48d6-922e-ca53f22c385f",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000157071,
        "executionStopTime": 1726000157317,
        "serverExecutionDuration": 2.9501467943192,
        "requestMsgId": "30694976-da54-48d6-922e-ca53f22c385f",
        "outputsInitialized": true,
        "customOutput": null,
        "id": "t2plfyrWcspH"
      },
      "source": [
        "class InferenceModule(torch.nn.Module):\n",
        "    def __init__(self, ebc: torchrec.EmbeddingBagCollection):\n",
        "        super().__init__()\n",
        "        self.ebc_ = ebc\n",
        "\n",
        "    def forward(self, kjt: KeyedJaggedTensor):\n",
        "        return self.ebc_(kjt)"
      ],
      "execution_count": 38,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "2a4a83f1-449d-493e-8f24-7c1975ecad9d",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000157320,
        "executionStopTime": 1726000157494,
        "serverExecutionDuration": 3.8229525089264,
        "requestMsgId": "2a4a83f1-449d-493e-8f24-7c1975ecad9d",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1619365005294308"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5FRioGEmcspH",
        "outputId": "e4d9efd3-2427-4602-bb28-2f30c4f3f985"
      },
      "source": [
        "module = InferenceModule(ebc)\n",
        "for name, param in module.named_parameters():\n",
        "    # The parameters are still FP32, since this is a standard (unquantized) EBC;\n",
        "    # FP32 is the default dtype and is what training typically uses\n",
        "    print(name, param.shape, param.dtype)"
      ],
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32\n",
            "ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "665352e2-208f-4951-8601-282d036b0e4e",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "OSTy4SU8cspH"
      },
      "source": [
        "### Quantization\n",
        "\n",
        "As you can see above, the standard EBC stores its embedding table weights in FP32 precision (32 bits per weight). Here, we use the TorchRec inference library to quantize the model's embedding weights to INT8, so that each weight takes 8 bits plus a small per-row overhead for the quantization parameters."
      ]
    },
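    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea concrete, here is a minimal sketch of row-wise INT8 quantization in plain PyTorch. It assumes an asymmetric scheme: each row gets its own scale and bias (the row minimum), and every weight is mapped to an 8-bit code in [0, 255]. This illustrates the general technique only. It is not the FBGEMM kernel, which for example stores scale and bias in a more compact format. The helpers `quantize_rowwise_int8` and `dequantize_rowwise_int8` are hypothetical.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "\n",
        "def quantize_rowwise_int8(weight: torch.Tensor):\n",
        "    # Hypothetical helper: each row's min/max define an affine map onto 0..255\n",
        "    row_min = weight.min(dim=1, keepdim=True).values\n",
        "    row_max = weight.max(dim=1, keepdim=True).values\n",
        "    # Clamp the scale so constant rows do not divide by zero\n",
        "    scale = ((row_max - row_min) / 255.0).clamp(min=1e-8)\n",
        "    bias = row_min\n",
        "    codes = torch.round((weight - bias) / scale).clamp(0, 255).to(torch.uint8)\n",
        "    return codes, scale, bias\n",
        "\n",
        "\n",
        "def dequantize_rowwise_int8(codes, scale, bias):\n",
        "    # Undo the affine map to get approximate FP32 weights back\n",
        "    return codes.to(torch.float32) * scale + bias\n",
        "\n",
        "\n",
        "torch.manual_seed(0)\n",
        "weight = torch.randn(4096, 64)  # same shape as one table in this tutorial\n",
        "codes, scale, bias = quantize_rowwise_int8(weight)\n",
        "recovered = dequantize_rowwise_int8(codes, scale, bias)\n",
        "\n",
        "print(codes.dtype, codes.shape)  # torch.uint8 torch.Size([4096, 64])\n",
        "print('max abs error:', (recovered - weight).abs().max().item())\n",
        "```\n",
        "\n",
        "Rounding to the nearest code limits each weight's reconstruction error to half of its row's scale."
      ]
    },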
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "796919b4-f9dd-4d14-a40e-f20668c8257b",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000157499,
        "executionStopTime": 1726000157696,
        "serverExecutionDuration": 14.22468572855,
        "requestMsgId": "796919b4-f9dd-4d14-a40e-f20668c8257b",
        "customOutput": null,
        "outputsInitialized": true,
        "output": {
          "id": "560049189691202"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oV-KPRqDcspH",
        "outputId": "e91220a1-a26c-4f2a-e7c8-e91a3d56d8dc"
      },
      "source": [
        "from torch import quantization as quant\n",
        "from torchrec.modules.embedding_configs import QuantConfig\n",
        "from torchrec.quant.embedding_modules import (\n",
        "    EmbeddingBagCollection as QuantEmbeddingBagCollection,\n",
        ")\n",
        "\n",
        "\n",
        "quant_dtype = torch.int8\n",
        "\n",
        "\n",
        "qconfig = QuantConfig(\n",
        "    # dtype of the embedding lookup result (the activation); torch.float is\n",
        "    # generally used for compatibility with the rest of the model,\n",
        "    # which usually isn't quantized\n",
        "    activation=quant.PlaceholderObserver.with_args(dtype=torch.float),\n",
        "    # quantized dtype for the embedding weights, i.e. the parameters that actually get quantized\n",
        "    weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),\n",
        ")\n",
        "qconfig_spec = {\n",
        "    # Map of module type to qconfig\n",
        "    torchrec.EmbeddingBagCollection: qconfig,\n",
        "}\n",
        "mapping = {\n",
        "    # Map of module type to quantized module type\n",
        "    torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,\n",
        "}\n",
        "\n",
        "\n",
        "module = InferenceModule(ebc)\n",
        "\n",
        "# Quantize the module\n",
        "qebc = quant.quantize_dynamic(\n",
        "    module,\n",
        "    qconfig_spec=qconfig_spec,\n",
        "    mapping=mapping,\n",
        "    inplace=False,\n",
        ")\n",
        "\n",
        "\n",
        "print(f\"Quantized EBC: {qebc}\")"
      ],
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Quantized EBC: InferenceModule(\n",
            "  (ebc_): QuantizedEmbeddingBagCollection(\n",
            "    (_kjt_to_jt_dict): ComputeKJTToJTDict()\n",
            "    (embedding_bags): ModuleDict(\n",
            "      (product_table): Module()\n",
            "      (user_table): Module()\n",
            "    )\n",
            "  )\n",
            ")\n"
          ]
        }
      ]
    },
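    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`quantize_dynamic` walks the module tree and replaces each submodule whose type appears in `qconfig_spec` with the quantized type given by `mapping`. As an illustration outside TorchRec, here is the same standard PyTorch API applied to a hypothetical toy model. With no explicit `mapping`, PyTorch's default dynamic-quantization mapping swaps `nn.Linear` for a dynamically quantized linear module. Running it assumes a CPU quantization backend such as fbgemm or qnnpack is available.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "from torch import quantization as quant\n",
        "\n",
        "\n",
        "class Toy(nn.Module):\n",
        "    # Hypothetical model: a Linear layer (quantized below) followed by a ReLU (left as is)\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.fc = nn.Linear(16, 8)\n",
        "        self.act = nn.ReLU()\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.act(self.fc(x))\n",
        "\n",
        "\n",
        "model = Toy().eval()\n",
        "# Swap every nn.Linear for its dynamically quantized INT8 counterpart\n",
        "qmodel = quant.quantize_dynamic(\n",
        "    model, qconfig_spec={nn.Linear}, dtype=torch.qint8, inplace=False\n",
        ")\n",
        "\n",
        "print(type(model.fc).__name__)  # Linear: the original is untouched since inplace=False\n",
        "print(type(qmodel.fc))  # a dynamically quantized Linear class\n",
        "out = qmodel(torch.randn(4, 16))\n",
        "print(out.shape)  # torch.Size([4, 8])\n",
        "```\n",
        "\n",
        "This mirrors the TorchRec call above, where `torchrec.EmbeddingBagCollection` is mapped to the quantized `EmbeddingBagCollection` from `torchrec.quant.embedding_modules`."
      ]
    },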
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "c1fdd88b-73af-47a8-8aec-4f9422051ee7",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000157700,
        "executionStopTime": 1726000157862,
        "serverExecutionDuration": 4.0535479784012,
        "requestMsgId": "c1fdd88b-73af-47a8-8aec-4f9422051ee7",
        "outputsInitialized": true,
        "customOutput": null,
        "id": "fAztesVacspI"
      },
      "source": [
        "kjt = kjt.to(\"cpu\")"
      ],
      "execution_count": 41,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000157865,
        "executionStopTime": 1726000158060,
        "serverExecutionDuration": 9.1104581952095,
        "requestMsgId": "f5f911e8-ab78-4fd7-b4a1-7a545b5bd24b",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "434299789062153"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Wnpwa0TmcspI",
        "outputId": "88007466-88ce-4f6b-b7e8-22d042e5378b"
      },
      "source": [
        "qebc(kjt)"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<torchrec.sparse.jagged_tensor.KeyedTensor at 0x7953dc65db70>"
            ]
          },
          "metadata": {},
          "execution_count": 42
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "99559efa-baaa-4de1-91d3-7899f87fe659",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000158063,
        "executionStopTime": 1726000158228,
        "serverExecutionDuration": 3.4465603530407,
        "requestMsgId": "99559efa-baaa-4de1-91d3-7899f87fe659",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "499581679596627"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UUs5fXNncspI",
        "outputId": "79d814ed-7772-4bc3-ba84-f5f0d0a45e36"
      },
      "source": [
        "# Once quantized, the weights move from parameters to buffers, as they are no longer trainable\n",
        "for name, buffer in qebc.named_buffers():\n",
        "    # The weights are now stored as raw bytes (torch.uint8): each row holds the 64 INT8 codes\n",
        "    # plus per-row quantization parameters, padded for alignment, so rows are wider than 64\n",
        "    print(name, buffer.shape, buffer.dtype)"
      ],
      "execution_count": 43,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8\n",
            "ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8\n"
          ]
        }
      ]
    },
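    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Simple arithmetic reproduces the 80 columns. This sketch makes two assumptions. First, each INT8 row stores its 64 one-byte codes followed by 4 bytes of per-row scale and bias (two FP16 values, as in FBGEMM's inference TBE layout). Second, each row is padded to the `row_alignment = 16` bytes that appear in the traced graph later in this section. The helpers `round_up` and `int8_row_bytes` are hypothetical.\n",
        "\n",
        "```python\n",
        "def round_up(x: int, alignment: int) -> int:\n",
        "    # Smallest multiple of `alignment` that is >= x\n",
        "    return (x + alignment - 1) // alignment * alignment\n",
        "\n",
        "\n",
        "def int8_row_bytes(dim: int, scale_bias_bytes: int = 4, row_alignment: int = 16) -> int:\n",
        "    # Assumed layout: dim one-byte codes + per-row scale/bias, padded to row_alignment\n",
        "    return round_up(dim + scale_bias_bytes, row_alignment)\n",
        "\n",
        "\n",
        "num_rows, dim, num_tables = 4096, 64, 2\n",
        "\n",
        "fp32_bytes = num_tables * num_rows * dim * 4  # 4 bytes per FP32 weight\n",
        "int8_bytes = num_tables * num_rows * int8_row_bytes(dim)\n",
        "\n",
        "print(int8_row_bytes(dim))  # 80, matching the buffer shapes above\n",
        "print(fp32_bytes, int8_bytes)  # 2097152 655360\n",
        "print(f'compression: {fp32_bytes / int8_bytes:.1f}x')  # compression: 3.2x\n",
        "```\n",
        "\n",
        "Under these assumptions, the two 4096 x 64 tables shrink from 2 MiB in FP32 to 640 KiB. That is a 3.2x reduction rather than the full 4x, because of the per-row overhead and padding."
      ]
    },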
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "2b1a9c89-b921-4a35-9f64-0c63b09a2579",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "fdM7UihocspI"
      },
      "source": [
        "### Shard\n",
        "\n",
        "Here we shard the quantized TorchRec model. Sharding replaces the quantized EBC with a sharded module that runs the performant FBGEMM TBE (table-batched embedding) kernels. We use a single device, consistent with training, so all tables are served by one TBE."
      ]
    },
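    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, a table-batched embedding stores all tables in one contiguous buffer and records where each table starts, so a single operator can serve every feature. The traced graph later in this section shows this as one FBGEMM lookup with `total_D = 128`: the two 64-dimensional tables side by side. The plain-PyTorch sketch below imitates that idea with a single `torch.nn.functional.embedding_bag` call over a concatenated weight. It is not how FBGEMM implements its fused kernels, and `table_offsets` and `batched_lookup` are hypothetical names.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "torch.manual_seed(0)\n",
        "num_rows, dim = 8, 4\n",
        "product_table = torch.randn(num_rows, dim)\n",
        "user_table = torch.randn(num_rows, dim)\n",
        "\n",
        "# One buffer for all tables, plus the first row of each table within it\n",
        "weights = torch.cat([product_table, user_table])\n",
        "table_offsets = {'product': 0, 'user': num_rows}\n",
        "\n",
        "\n",
        "def batched_lookup(ids_per_feature):\n",
        "    # Shift each feature's ids into its table's region of the shared buffer\n",
        "    indices, offsets = [], []\n",
        "    for name, bags in ids_per_feature.items():\n",
        "        for bag in bags:\n",
        "            offsets.append(len(indices))\n",
        "            indices.extend(i + table_offsets[name] for i in bag)\n",
        "    # A single sum-pooled lookup covers every bag of every feature\n",
        "    pooled = F.embedding_bag(\n",
        "        torch.tensor(indices), weights, torch.tensor(offsets), mode='sum'\n",
        "    )\n",
        "    # Rows come out feature-major; regroup to [batch, total_dim] with the\n",
        "    # features side by side (assumes every feature has the same batch size)\n",
        "    num_features = len(ids_per_feature)\n",
        "    batch_size = len(offsets) // num_features\n",
        "    return pooled.view(num_features, batch_size, dim).transpose(0, 1).reshape(batch_size, -1)\n",
        "\n",
        "\n",
        "# A batch of 2 samples; each sample has one bag of ids per feature\n",
        "ids = {'product': [[1, 2], [3]], 'user': [[0], [5, 6, 7]]}\n",
        "out = batched_lookup(ids)\n",
        "print(out.shape)  # torch.Size([2, 8])\n",
        "```"
      ]
    },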
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "19c18bbb-6376-468a-a6dc-8346d30ceb48",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000158234,
        "executionStopTime": 1726000158552,
        "serverExecutionDuration": 108.51271077991,
        "requestMsgId": "19c18bbb-6376-468a-a6dc-8346d30ceb48",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "882684747065056"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mha4FntncspI",
        "outputId": "40a5557c-531f-4726-f06a-5f73546a8fe0"
      },
      "source": [
        "from torchrec import distributed as trec_dist\n",
        "from torchrec.distributed.shard import _shard_modules\n",
        "\n",
        "\n",
        "sharded_qebc = _shard_modules(\n",
        "    module=qebc,\n",
        "    device=torch.device(\"cpu\"),\n",
        "    env=trec_dist.ShardingEnv.from_local(\n",
        "        1,  # world_size: a single device\n",
        "        0,  # rank of this process\n",
        "    ),\n",
        ")\n",
        "\n",
        "\n",
        "print(f\"Sharded Quantized EBC: {sharded_qebc}\")"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sharded Quantized EBC: InferenceModule(\n",
            "  (ebc_): ShardedQuantEmbeddingBagCollection(\n",
            "    (lookups): \n",
            "     InferGroupedPooledEmbeddingsLookup()\n",
            "    (_output_dists): ModuleList()\n",
            "    (embedding_bags): ModuleDict(\n",
            "      (product_table): Module()\n",
            "      (user_table): Module()\n",
            "    )\n",
            "    (_input_dist_module): ShardedQuantEbcInputDist()\n",
            "  )\n",
            ")\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "f00ae63f-0ac4-49c0-93fe-32d7fac76693",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000158555,
        "executionStopTime": 1726000159111,
        "serverExecutionDuration": 345.11629864573,
        "requestMsgId": "f00ae63f-0ac4-49c0-93fe-32d7fac76693",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "876807203893705"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0iBD90t3cspI",
        "outputId": "86a17997-aa50-4427-cf2b-55f4c0aef456"
      },
      "source": [
        "sharded_qebc(kjt)"
      ],
      "execution_count": 45,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<torchrec.sparse.jagged_tensor.KeyedTensor at 0x7953dc65c760>"
            ]
          },
          "metadata": {},
          "execution_count": 45
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "897037bb-9d81-4a33-aea1-de1691217d41",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "08ue1zeVcspI"
      },
      "source": [
        "### Compilation\n",
        "Now we have an optimized eager-mode TorchRec inference model. The next step is to make this model loadable in C++, since it currently runs only in a Python runtime.\n",
        "\n",
        "The compilation method recommended at Meta has two steps. First, [torch.fx tracing](https://pytorch.org/docs/stable/fx.html) generates an intermediate representation of the model. Then the result is converted to TorchScript, which can be loaded from C++."
      ]
    },
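    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can try the same two-step recipe on a small standalone module using core PyTorch APIs. `torch.fx.symbolic_trace` produces a `GraphModule`, `torch.jit.script` converts it to TorchScript, and `torch.jit.save` writes an archive that the C++ API can load with `torch::jit::load`. `TinyModel` is hypothetical. TorchRec's own `Tracer`, used below, is a `torch.fx` tracer adapted to TorchRec modules, which this sketch does not need.\n",
        "\n",
        "```python\n",
        "import io\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "\n",
        "class TinyModel(nn.Module):\n",
        "    # Hypothetical stand-in for an inference model\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.bag = nn.EmbeddingBag(10, 4, mode='sum')\n",
        "\n",
        "    def forward(self, indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:\n",
        "        return self.bag(indices, offsets)\n",
        "\n",
        "\n",
        "model = TinyModel().eval()\n",
        "\n",
        "# Step 1: symbolic tracing captures the forward pass as an FX graph\n",
        "gm = torch.fx.symbolic_trace(model)\n",
        "\n",
        "# Step 2: script the traced GraphModule into TorchScript\n",
        "scripted = torch.jit.script(gm)\n",
        "\n",
        "# Serialize to a TorchScript archive and load it back (C++ would use torch::jit::load)\n",
        "buffer = io.BytesIO()\n",
        "torch.jit.save(scripted, buffer)\n",
        "buffer.seek(0)\n",
        "loaded = torch.jit.load(buffer)\n",
        "\n",
        "indices = torch.tensor([1, 2, 3, 4])\n",
        "offsets = torch.tensor([0, 2])\n",
        "print(torch.allclose(loaded(indices, offsets), model(indices, offsets)))  # True\n",
        "```"
      ]
    },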
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "bdab6e95-3a71-4c3d-b188-115873f1f5d5",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000159118,
        "executionStopTime": 1726000159308,
        "serverExecutionDuration": 28.788283467293,
        "requestMsgId": "bdab6e95-3a71-4c3d-b188-115873f1f5d5",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "491668137118498"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SRzo1jljcspI",
        "outputId": "e0f94cf0-c5cf-4480-b0d9-d5dd2abb459b"
      },
      "source": [
        "from torchrec.fx import Tracer\n",
        "\n",
        "\n",
        "tracer = Tracer(leaf_modules=[\"IntNBitTableBatchedEmbeddingBagsCodegen\"])\n",
        "\n",
        "graph = tracer.trace(sharded_qebc)\n",
        "gm = torch.fx.GraphModule(sharded_qebc, graph)\n",
        "\n",
        "print(\"Graph Module Created!\")"
      ],
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Graph Module Created!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "909178d6-4dae-45da-9c39-6827019f53a3",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000159312,
        "executionStopTime": 1726000159490,
        "serverExecutionDuration": 2.2248737514019,
        "requestMsgId": "909178d6-4dae-45da-9c39-6827019f53a3",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1555501808508272"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NsgzbUdHcspI",
        "outputId": "c5b67630-19c7-46df-c0ab-216d24309603"
      },
      "source": [
        "print(gm.code)"
      ],
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embeddingbag_flatten_feature_lengths\")\n",
            "torch.fx._symbolic_trace.wrap(\"torchrec_fx_utils__fx_marker\")\n",
            "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_quant_embedding_kernel__unwrap_kjt\")\n",
            "torch.fx._symbolic_trace.wrap(\"torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference\")\n",
            "\n",
            "def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):\n",
            "    flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt);  kjt = None\n",
            "    _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths)\n",
            "    split = flatten_feature_lengths.split([2])\n",
            "    getitem = split[0];  split = None\n",
            "    to = getitem.to(device(type='cuda', index=0), non_blocking = True);  getitem = None\n",
            "    _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths);  flatten_feature_lengths = None\n",
            "    _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to);  to = None\n",
            "    getitem_1 = _unwrap_kjt[0]\n",
            "    getitem_2 = _unwrap_kjt[1]\n",
            "    getitem_3 = _unwrap_kjt[2];  _unwrap_kjt = None\n",
            "    _tensor_constant0 = self._tensor_constant0\n",
            "    _tensor_constant1 = self._tensor_constant1\n",
            "    bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None);  _tensor_constant0 = _tensor_constant1 = None\n",
            "    _tensor_constant2 = self._tensor_constant2\n",
            "    _tensor_constant3 = self._tensor_constant3\n",
            "    _tensor_constant4 = self._tensor_constant4\n",
            "    _tensor_constant5 = self._tensor_constant5\n",
            "    _tensor_constant6 = self._tensor_constant6\n",
            "    _tensor_constant7 = self._tensor_constant7\n",
            "    _tensor_constant8 = self._tensor_constant8\n",
            "    _tensor_constant9 = self._tensor_constant9\n",
            "    int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1);  _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None\n",
            "    embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32);  int_nbit_split_embedding_codegen_lookup_function = None\n",
            "    to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu'));  embeddings_cat_empty_rank_handle_inference = None\n",
            "    keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1);  to_1 = None\n",
            "    return keyed_tensor\n",
            "    \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000159494,
        "executionStopTime": 1726000160206,
        "serverExecutionDuration": 540.64276814461,
        "requestMsgId": "ec77b6ea-f5b1-4c08-9cb9-93faf6a57532",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "978016470760577"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CjjJLc6pcspI",
        "outputId": "be3a9486-e4b5-43f6-aed0-711e827a0040"
      },
      "source": [
        "scripted_gm = torch.jit.script(gm)\n",
        "print(\"Scripted Graph Module Created!\")"
      ],
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Scripted Graph Module Created!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "originalKey": "9eb089f1-2771-419d-a48e-3b7330c0a1e4",
        "showInput": true,
        "customInput": null,
        "language": "python",
        "executionStartTime": 1726000160212,
        "executionStopTime": 1726000160395,
        "serverExecutionDuration": 2.8529539704323,
        "requestMsgId": "9eb089f1-2771-419d-a48e-3b7330c0a1e4",
        "outputsInitialized": true,
        "customOutput": null,
        "output": {
          "id": "1020643789855657"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BWKPRaI3cspI",
        "outputId": "273181a2-7c91-4167-e814-4a07b51c6b10"
      },
      "source": [
        "print(scripted_gm.code)"
      ],
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "def forward(self,\n",
            "    kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:\n",
            "  _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths\n",
            "  _1 = __torch__.torchrec.fx.utils._fx_marker\n",
            "  _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt\n",
            "  _3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference\n",
            "  flatten_feature_lengths = _0(kjt, )\n",
            "  _fx_marker = _1(\"KJT_ONE_TO_ALL_FORWARD_BEGIN\", flatten_feature_lengths, )\n",
            "  split = (flatten_feature_lengths).split([2], )\n",
            "  getitem = split[0]\n",
            "  to = (getitem).to(torch.device(\"cuda\", 0), True, None, )\n",
            "  _fx_marker_1 = _1(\"KJT_ONE_TO_ALL_FORWARD_END\", flatten_feature_lengths, )\n",
            "  _unwrap_kjt = _2(to, )\n",
            "  getitem_1 = (_unwrap_kjt)[0]\n",
            "  getitem_2 = (_unwrap_kjt)[1]\n",
            "  _tensor_constant0 = self._tensor_constant0\n",
            "  _tensor_constant1 = self._tensor_constant1\n",
            "  ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)\n",
            "  _tensor_constant2 = self._tensor_constant2\n",
            "  _tensor_constant3 = self._tensor_constant3\n",
            "  _tensor_constant4 = self._tensor_constant4\n",
            "  _tensor_constant5 = self._tensor_constant5\n",
            "  _tensor_constant6 = self._tensor_constant6\n",
            "  _tensor_constant7 = self._tensor_constant7\n",
            "  _tensor_constant8 = self._tensor_constant8\n",
            "  _tensor_constant9 = self._tensor_constant9\n",
            "  int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)\n",
            "  _4 = [int_nbit_split_embedding_codegen_lookup_function]\n",
            "  embeddings_cat_empty_rank_handle_inference = _3(_4, 1, \"cuda:0\", 6, )\n",
            "  to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device(\"cpu\"))\n",
            "  _5 = [\"product\", \"user\"]\n",
            "  _6 = [64, 64]\n",
            "  keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)\n",
            "  _7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )\n",
            "  return keyed_tensor\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "9a1dda10-b9cf-4d9f-b068-51ae3ce3ffc1",
        "showInput": false,
        "customInput": null,
        "language": "markdown",
        "outputsInitialized": false,
        "id": "DQiGRYOgcspI"
      },
      "source": [
        "## Congrats!\n",
        "\n",
        "You have now taken a distributed RecSys model from training all the way to being inference-ready. The [TorchRec inference directory](https://github.com/pytorch/torchrec/tree/main/torchrec/inference) has a full example of how to load a TorchRec TorchScript model into C++ for inference."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ebXfh7oW9fHH",
        "originalKey": "4ca6a593-9ac9-4e2f-bc9a-8c8a1887ad41",
        "outputsInitialized": false,
        "language": "markdown",
        "showInput": false
      },
      "source": [
        "## More resources\n",
        "For more information, see our [dlrm](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/) example, which includes multinode training of Meta’s [DLRM](https://arxiv.org/abs/1906.00091) on the Criteo Terabyte dataset."
      ]
    }
  ]
}
